# ########################################
# Recipe for training an emotion recognition system from speech data
# only using IEMOCAP and a SpeechTokenizer codec as the feature extractor
# The system classifies 4 emotions (anger, happiness, sadness, neutrality)
# with a discrete embedding layer, an attention MLP and a linear output layer.
# ########################################

# Seed needs to be set at top of yaml, before objects with parameters are made
# The seed value is also used as the last component of <output_folder>.
seed: 1986
# `!apply` calls torch.manual_seed(<seed>) while this file is being loaded.
__set_seed: !apply:torch.manual_seed [!ref <seed>]

# Root folder of the IEMOCAP dataset. Declared as a placeholder, so a value
# must be supplied by the user when the recipe is launched.
data_folder: !PLACEHOLDER # e.g., /path/to/IEMOCAP_full_release
# Everything produced by a run is grouped under a per-seed output folder.
output_folder: !ref results/IEMOCAP/IEMOCAP_full_release/linear/speech_tokenizer/<seed>
# Used as the checkpoint directory and as the codec `save_path` further down.
save_folder: !ref <output_folder>/save
# File written by the train logger.
train_log: !ref <output_folder>/train_log.txt

# URL for the SSL encoder model; you can change it to benchmark different models.
# Important: we use the base SSL encoder and not the one fine-tuned on the ASR task.
# This allows you to have ~4% improvement.

### Config for Tokenizer
# Passed as `vocab_size` to the discrete embedding layer.
vocab_size: 1024
# Passed as `num_codebooks` to the discrete embedding layer.
num_codebooks: 2
# Not referenced elsewhere in this file; presumably the audio sampling rate
# in Hz expected by the training script -- TODO confirm.
sample_rate: 16000

# Feature parameters
# Shared width: embedding dim, attention MLP input/hidden dim and
# input size of the output layer all reference this value.
encoder_dim: 1024

# different speakers for train, valid and test sets
different_speakers: True
# which speaker is used for test set, value from 1 to 10
# For a complete evaluation, run the recipe once for each of the 10 values
# and take the mean of the results.
test_spk_id: 1

# Path where data manifest files will be stored
train_annotation: !ref <output_folder>/train.json
valid_annotation: !ref <output_folder>/valid.json
test_annotation: !ref <output_folder>/test.json
# Presumably skips the data preparation step when set to True --
# confirm against the training script.
skip_prep: False

# The train logger writes training statistics to a file, as well as stdout.
# The file is <train_log>, i.e. <output_folder>/train_log.txt.
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

# Not referenced elsewhere in this file; presumably read by the training
# script / Brain class -- confirm.
ckpt_interval_minutes: 15 # save checkpoint every N min

# Training parameters
precision: fp32
# Used as the limit of the epoch counter.
number_of_epochs: 30
# Batch size for both the train and the valid dataloaders.
batch_size: 2
# Batch size for the test dataloader.
test_batch_size: 1

# Learning rate given to the optimizer and used as the initial value of
# the learning-rate scheduler.
lr: 0.0002

# Number of emotions
# Sets the number of output neurons of the final linear layer.
out_n_neurons: 4 # (anger, happiness, sadness, neutral)

# Dataloader options
# Training data is shuffled and the last incomplete batch is kept
# (drop_last: False).
train_dataloader_opts:
    batch_size: !ref <batch_size>
    shuffle: True
    num_workers: 2  # 2 on linux but 0 works on windows
    drop_last: False

# Validation reuses the training batch size; no other option is set here.
valid_dataloader_opts:
    batch_size: !ref <batch_size>

# Testing uses its own batch size; no other option is set here.
test_dataloader_opts:
    batch_size: !ref <test_batch_size>

# Modules
# SpeechTokenizer codec interface; `save_path` points at <save_folder>.
# NOTE(review): `sample_rate` in this file is 16000 -- make sure it matches
# the rate the `fnlp/SpeechTokenizer` checkpoint expects.
codec: !new:speechbrain.lobes.models.discrete.speechtokenizer_interface.SpeechTokenizer_interface
    source: fnlp/SpeechTokenizer  # model identifier handed to the interface
    save_path: !ref <save_folder>

# Embedding layer from the recipe-local `custom_model` module, configured
# with the number of codebooks, the vocabulary size and the embedding dim.
discrete_embedding_layer: !new:custom_model.Discrete_EmbeddingLayer
    num_codebooks: !ref <num_codebooks>
    vocab_size: !ref <vocab_size>
    emb_dim: !ref <encoder_dim>

# Attention MLP from `custom_model`; input and hidden dims both equal
# <encoder_dim>. Presumably used to weight the per-codebook embeddings --
# confirm in custom_model.
attention_mlp: !new:custom_model.AttentionMLP
    input_dim: !ref <encoder_dim>
    hidden_dim: !ref <encoder_dim>

# Statistics pooling with the standard-deviation output disabled.
avg_pool: !new:speechbrain.nnet.pooling.StatisticsPooling
    return_std: False

# Final linear layer without bias: <encoder_dim> inputs, one output neuron
# per emotion class.
output_mlp: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <encoder_dim>
    n_neurons: !ref <out_n_neurons>
    bias: False

# Epoch counter limited to <number_of_epochs>; also registered in the
# checkpointer below.
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

# Named modules exposed under the `modules` key (presumably handed to the
# Brain class by the training script -- confirm). `avg_pool` and
# `log_softmax` are not listed here.
modules:
    output_mlp: !ref <output_mlp>
    attention_mlp: !ref <attention_mlp>
    codec: !ref <codec>
    discrete_embedding_layer: !ref <discrete_embedding_layer>


# Groups the output layer, the embedding layer and the attention MLP into
# one ModuleList; the codec is not part of it. The outer sequence is the
# positional-argument list, the inner list is the single argument given to
# ModuleList. This object is what the checkpointer stores as `model`.
model: !new:torch.nn.ModuleList
    - [!ref <output_mlp>, !ref <discrete_embedding_layer>, !ref <attention_mlp>]

# Softmax activation configured to return log-probabilities
# (apply_log: True).
log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True

# `!name` yields the function itself (not a call result); negative
# log-likelihood loss.
compute_cost: !name:speechbrain.nnet.losses.nll_loss

# Factory for a MetricStats tracker that uses classification error with
# `reduction: batch` as its metric.
error_stats: !name:speechbrain.utils.metric_stats.MetricStats
    metric: !name:speechbrain.nnet.losses.classification_error
        reduction: batch

# Adam optimizer factory with the learning rate pre-bound; the parameters
# to optimize are presumably supplied by the training script -- confirm.
model_opt_class: !name:torch.optim.Adam
    lr: !ref <lr>


# NewBob learning-rate scheduler starting from <lr>.
# Presumably multiplies the learning rate by `annealing_factor` when the
# relative improvement of the monitored value falls below
# `improvement_threshold`, with `patient` being the number of such epochs
# tolerated first -- confirm against the NewBobScheduler documentation.
lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
    improvement_threshold: 0.0025
    annealing_factor: 0.9
    patient: 0

# Checkpointer writing to <save_folder>. Each entry under `recoverables`
# is saved and restored under the given name.
# NOTE(review): `attention_mlp` and `discrete_embedding_layer` are also
# members of the `model` ModuleList, so they are registered twice; the
# codec is checkpointed as well. Confirm whether this duplication is
# intended before changing it.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        scheduler_model: !ref <lr_annealing_model>
        attention_mlp: !ref <attention_mlp>
        codec: !ref <codec>
        discrete_embedding_layer: !ref <discrete_embedding_layer>
        counter: !ref <epoch_counter>
