# ############################################################################
# Recipe for "direct" (speech -> scenario)
# "Intent" classification using the SLURP dataset.
# SLURP contains 18 scenario classes (calendar, email, ...).
# Input waveforms are encoded into features with an SSL encoder.
# Probing is done with time pooling followed by a linear classifier.
# ############################################################################

# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 1986
# `!apply` invokes torch.manual_seed(<seed>) while this file is being loaded;
# the key only exists to trigger that call.
__set_seed: !apply:torch.manual_seed [!ref <seed>]
# All run artifacts are grouped under a seed-specific folder, so runs with
# different seeds do not overwrite each other.
output_folder: !ref results/SLURP/linear/weighted_ssl/<seed>
# Used as the checkpointer's `checkpoints_dir`.
save_folder: !ref <output_folder>/save
# Used as the train logger's `save_file`.
train_log: !ref <output_folder>/train_log.txt

# Data files
# The SLURP dataset will be automatically downloaded in the specified folder
# `!PLACEHOLDER` has no value: it must be supplied as an override when the
# file is loaded.
data_folder: !PLACEHOLDER
# data_folder_rirs: !ref <data_folder>
# List of SLURP training splits to use.
train_splits: ["train_real"]
# CSV manifests for the train / devel / test partitions, stored in the output
# folder. Presumably written by the data-preparation step unless `skip_prep`
# is set -- confirm against the training script.
csv_train: !ref <output_folder>/train-type=direct.csv
csv_valid: !ref <output_folder>/devel-type=direct.csv
csv_test: !ref <output_folder>/test-type=direct.csv
skip_prep: False

# Loss function reference (`!name` yields a callable, it is not invoked at
# load time): negative log-likelihood without extra keyword arguments.
compute_cost: !name:speechbrain.nnet.losses.nll_loss
# Hub identifier of the SSL encoder; change it to benchmark different models.
# It is passed as `hub` to weighted_ssl_model, and `ssl_folder` as `save_path`.
ssl_hub: microsoft/wavlm-large
ssl_folder: !ref <output_folder>/ssl_checkpoints
# Feeds output_mlp.input_size, so it presumably has to match the feature size
# produced by the chosen encoder -- update it together with ssl_hub.
encoder_dim: 1024

# Training parameters
precision: fp32
# Upper limit given to epoch_counter.
number_of_epochs: 20
# Batch size for the train and valid dataloaders; the test dataloader uses
# test_batch_size.
batch_size: 2
test_batch_size: 1
# Learning rate of model_opt_class and initial value of lr_annealing_model.
lr: 0.0002
# Learning rate of weights_opt_class and initial value of
# lr_annealing_weights.
lr_weights: 0.01
# token_type: unigram # ["unigram", "bpe", "char"]
sorting: random
ckpt_interval_minutes: 15 # save checkpoint every N min

# Model parameters
sample_rate: 16000
# NOTE(review): emb_size and dec_neurons are not referenced by any `!ref` in
# this file -- confirm whether the training script reads them.
emb_size: 128
dec_neurons: 512
output_neurons: 18 # number of scenario classes to predict

# Dataloader options
# Training: shuffled batches of <batch_size>; the last incomplete batch is
# kept (drop_last is False).
train_dataloader_opts:
   batch_size: !ref <batch_size>
   shuffle: True
   num_workers: 2  # 2 on linux but 0 works on windows
   drop_last: False

# Validation reuses the training batch size.
valid_dataloader_opts:
   batch_size: !ref <batch_size>

# Testing uses its own batch size (<test_batch_size>).
test_dataloader_opts:
   batch_size: !ref <test_batch_size>

# Decoding parameters
# NOTE(review): none of these keys is referenced by a `!ref` in this file and
# no beam searcher is defined here. They look like settings for
# sequence decoding rather than for the linear classifier -- confirm whether
# the training script still reads them.
bos_index: 0
eos_index: 0
min_decode_ratio: 0.0
max_decode_ratio: 10.0
slu_beam_size: 80
eos_threshold: 1.5
temperature: 1.25

# Generic dataloader options; repeats the batch_size and shuffle settings of
# train_dataloader_opts. NOTE(review): which of the two mappings the training
# script actually uses is not visible from this file.
dataloader_opts:
   batch_size: !ref <batch_size>
   shuffle: True

# Epoch counter limited to <number_of_epochs>. It is registered in the
# checkpointer under the `counter` key.
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
   limit: !ref <number_of_epochs>

# Models
# SSL encoder wrapper built from the <ssl_hub> identifier, with <ssl_folder>
# as its save path. Presumably it combines the encoder layers with learned
# weights, which would be what weights_opt_class / lr_weights train --
# confirm in the training script.
weighted_ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.WeightedSSLModel # yamllint disable-line rule:line-length
   hub: !ref <ssl_hub>
   save_path: !ref <ssl_folder>

# Time pooling of the encoder features. The standard-deviation output is
# disabled, which leaves mean pooling (assuming the mean output keeps its
# default setting -- confirm against StatisticsPooling).
avg_pool: !new:speechbrain.nnet.pooling.StatisticsPooling
   return_std: False

# Linear probe: a single bias-free linear layer that maps the
# <encoder_dim>-sized input to one output per class. The output width follows
# <output_neurons> instead of a hard-coded literal, so overriding the number
# of classes in one place keeps the classifier in sync.
output_mlp: !new:speechbrain.nnet.linear.Linear
   input_size: !ref <encoder_dim>
   n_neurons: !ref <output_neurons>
   bias: False

# ModuleList holding only the linear probe; this is the object stored by the
# checkpointer under the `model` key.
model: !new:torch.nn.ModuleList
   - [!ref <output_mlp>]

# Named modules of the recipe: pooling, linear probe and SSL encoder.
# Presumably handed to the training class as its module dictionary --
# confirm in the training script.
modules:
   avg_pool: !ref <avg_pool>
   output_mlp: !ref <output_mlp>
   weighted_ssl_model: !ref <weighted_ssl_model>

# SentencePieceProcessor created without arguments (no model file is given
# here). NOTE(review): nothing else in this file references it -- confirm
# whether the training script needs a tokenizer for this classification task.
tokenizer: !new:sentencepiece.SentencePieceProcessor


# Factory (`!name`, so a fresh object per call) for a MetricStats tracker
# whose metric is classification_error with `reduction: batch`.
error_stats: !name:speechbrain.utils.metric_stats.MetricStats
   metric: !name:speechbrain.nnet.losses.classification_error
      reduction: batch

# Two Adam optimizer factories with separate learning rates. `!name` only
# binds `lr`; the parameters to optimize are supplied by the caller.
# model_opt_class uses <lr>.
model_opt_class: !name:torch.optim.Adam
   lr: !ref <lr>

# weights_opt_class uses <lr_weights>.
weights_opt_class: !name:torch.optim.Adam
   lr: !ref <lr_weights>

# NewBob learning-rate scheduler starting at <lr>, with an annealing factor
# of 0.8 and an improvement threshold of 0.0025.
# NOTE(review): `patient` is set explicitly only on this scheduler;
# lr_annealing_weights relies on the class default -- confirm the asymmetry
# is intended.
lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
   initial_value: !ref <lr>
   improvement_threshold: 0.0025
   annealing_factor: 0.8
   patient: 0

# NewBob scheduler starting at <lr_weights>, with a milder annealing factor
# (0.9) and the same improvement threshold.
lr_annealing_weights: !new:speechbrain.nnet.schedulers.NewBobScheduler
   initial_value: !ref <lr_weights>
   improvement_threshold: 0.0025
   annealing_factor: 0.9

# Checkpointer writing to <save_folder>. It tracks the linear probe, the SSL
# encoder wrapper, both LR schedulers and the epoch counter.
# Note that the `scheduler_encoder` key stores lr_annealing_weights; the key
# names are presumably part of the saved checkpoint layout, so renaming them
# may break recovery of existing checkpoints -- confirm before changing.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
   checkpoints_dir: !ref <save_folder>
   recoverables:
      model: !ref <model>
      ssl_model: !ref <weighted_ssl_model>
      scheduler_model: !ref <lr_annealing_model>
      scheduler_encoder: !ref <lr_annealing_weights>
      counter: !ref <epoch_counter>

# Softmax activation configured with apply_log, i.e. it outputs
# log-probabilities (the input form expected by an NLL loss).
log_softmax: !new:speechbrain.nnet.activations.Softmax
   apply_log: True

# NLL loss with label smoothing 0.1 pre-bound. NOTE(review): compute_cost
# above is the same loss without smoothing; which of the two the training
# script uses is not visible from this file.
seq_cost: !name:speechbrain.nnet.losses.nll_loss
   label_smoothing: 0.1

# File-based training logger writing to <train_log>.
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
   save_file: !ref <train_log>

# Factories for ErrorRateStats trackers. NOTE(review): nothing in this file
# references them -- confirm whether the training script uses them for this
# classification task.
error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats

# Same tracker with split_tokens enabled (presumably character-level error
# rate, going by the key name -- confirm).
cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
   split_tokens: True
