# ############################################################################
# Recipe for "direct" (speech -> scenario) intent classification
# on the SLURP dataset.
# SLURP defines 18 scenario classes (calendar, email, ...).
# Input waveforms are encoded into discrete tokens with EnCodec.
# Probing is done with a bidirectional LSTM (RNN) encoder followed by a linear classifier.
############################################################################
# Seed needs to be set at the top of the yaml, before any objects with
# parameters are created: the `!apply:` below calls torch.manual_seed while the
# file is being loaded, so every `!new:` module further down is built after it.
seed: 1986
__set_seed: !apply:torch.manual_seed [!ref <seed>]
# Per-seed output folder; checkpoints (save_folder), the training log and the
# prepared CSV files all live under it.
output_folder: !ref results/SLURP/encodec/<seed>
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

# ---- Data files ----
# Root folder of SLURP. !PLACEHOLDER means it must be overridden, e.g.
# `--data_folder=/path/to/SLURP`. The dataset is downloaded automatically
# into this folder.
data_folder: !PLACEHOLDER
# data_folder_rirs: !ref <data_folder>
# Training split(s) to prepare, one per line.
train_splits:
   - train_real
# Prepared CSV manifests (one per split), written under output_folder.
csv_train: !ref <output_folder>/train-type=direct.csv
csv_valid: !ref <output_folder>/devel-type=direct.csv
csv_test: !ref <output_folder>/test-type=direct.csv
# Presumably skips the data-preparation step when True — confirm in the
# training script.
skip_prep: False

# Classification loss. `!name:` yields the (uncalled) function, i.e.
# SpeechBrain's NLL loss, which takes log-probabilities as predictions.
compute_cost: !name:speechbrain.nnet.losses.nll_loss

# Training parameters
# Numerical precision used for training (fp32 = full precision).
precision: fp32
number_of_epochs: 20
batch_size: 2
test_batch_size: 1
lr: 0.0002
# NOTE(review): lr_weights is not referenced by any other entry in this file;
# presumably read directly by the training script — confirm.
lr_weights: 0.01
# token_type: unigram # ["unigram", "bpe", "char"]
# Order in which training examples are batched — confirm accepted values in
# the training script.
sorting: random
ckpt_interval_minutes: 5 # save a checkpoint every N minutes

# Model parameters
# Number of output classes: one per SLURP scenario (18 scenarios).
# NOTE(review): the earlier "index(eos/bos) = 0" remark looked inherited from
# a seq2seq recipe; all 18 outputs here appear to be scenario classes.
output_neurons: 18

### Config for Tokenizer
# EnCodec parameters (24 kHz model, 1024 entries per codebook).
# Supported bandwidths (kbps) and the matching number of codebooks:
#   bandwidth:     [1.5, 3.0, 6.0, 12.0, 24.0]
#   num_codebooks: [2,   4,   8,   16,   32]
# `bandwidth` and `num_codebooks` below must be changed together.
vocab_size: 1024
bandwidth: 1.5
num_codebooks: 2
sample_rate: 24000
# Feature parameters
# Width of the token embeddings and of the LSTM encoder.
encoder_dim: 1024
# If init_embedding is True, encoder_dim must be set to the tokenizer's
# embedding dim (128 for EnCodec). NOTE(review): presumably this initializes
# the embedding table from the codec's own embeddings — confirm in
# custom_model.py.
init_embedding: False
# Presumably freezes the discrete embedding layer's weights when True.
freeze_embedding: False

# Dataloader options
# Training loader: shuffling enabled; the last incomplete batch is kept.
train_dataloader_opts:
   batch_size: !ref <batch_size>
   shuffle: True
   num_workers: 2  # 2 on Linux; 0 works on Windows
   drop_last: False

# Validation loader: same batch size as training; no shuffle option is set.
valid_dataloader_opts:
   batch_size: !ref <batch_size>

# Test loader: uses test_batch_size.
test_dataloader_opts:
   batch_size: !ref <test_batch_size>

# Modules
# Pretrained EnCodec used as a frozen tokenizer: it turns waveforms into
# discrete codes (num_codebooks codebooks at the chosen bandwidth).
# Docs: https://huggingface.co/docs/transformers/v4.31.0/en/model_doc/encodec
codec: !new:speechbrain.lobes.models.huggingface_transformers.encodec.Encodec
   source: facebook/encodec_24khz  # Only the 24kHz version supports mono audio
   # Presumably where the downloaded pretrained weights are stored.
   save_path: !ref <save_folder>
   sample_rate: !ref <sample_rate>
   bandwidth: !ref <bandwidth>
   flat_embeddings: False
   # Codec weights are not updated during training.
   freeze: True
   renorm_embeddings: False

# Project-local embedding layer (custom_model.py). Presumably maps the codec's
# token ids (num_codebooks codebooks x vocab_size entries each) to
# emb_dim = encoder_dim vectors. NOTE(review): how codebooks are combined is
# defined in custom_model.py, not here.
discrete_embedding_layer: !new:custom_model.Discrete_EmbeddingLayer
   num_codebooks: !ref <num_codebooks>
   vocab_size: !ref <vocab_size>
   emb_dim: !ref <encoder_dim>
   freeze: !ref <freeze_embedding>
   init: !ref <init_embedding>

# Project-local attention MLP (custom_model.py) on encoder_dim-wide features.
# NOTE(review): presumably used to weight/combine the per-codebook embeddings;
# confirm in custom_model.py and the training script.
attention_mlp: !new:custom_model.AttentionMLP
   input_dim: !ref <encoder_dim>
   hidden_dim: !ref <encoder_dim>

# Encoder: 2-layer bidirectional LSTM, then a linear projection from the
# concatenated forward/backward states (2 * encoder_dim) back to encoder_dim.
enc: !new:speechbrain.nnet.containers.Sequential
   # Rank-3 input with unknown first two sizes; assumed (batch, time, features).
   input_shape:
      - null
      - null
      - !ref <encoder_dim>
   lstm: !new:speechbrain.nnet.RNN.LSTM
      hidden_size: !ref <encoder_dim>
      input_size: !ref <encoder_dim>
      num_layers: 2
      bidirectional: True
   linear: !new:speechbrain.nnet.linear.Linear
      # Bidirectional output doubles the feature size.
      input_size: !ref <encoder_dim> * 2
      n_neurons: !ref <encoder_dim>

# Decoding parameters
# NOTE(review): none of these keys is referenced by any other entry in this
# file. They look inherited from a sequence-to-sequence SLURP recipe and may be
# unused for scenario classification; confirm in the training script before
# removing them.
bos_index: 0
eos_index: 0
min_decode_ratio: 0.0
max_decode_ratio: 10.0
slu_beam_size: 80
eos_threshold: 1.5
temperature: 1.25

# NOTE(review): not referenced by any other entry in this file and duplicates
# the batch_size/shuffle settings of train_dataloader_opts; possibly unused.
dataloader_opts:
   batch_size: !ref <batch_size>
   shuffle: True

# Epoch loop limited to number_of_epochs; it is registered with the
# checkpointer (as "counter") so a resumed run continues at the right epoch.
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
   limit: !ref <number_of_epochs>


# Statistics pooling over time; with return_std: False only the mean is
# returned (no standard deviation).
avg_pool: !new:speechbrain.nnet.pooling.StatisticsPooling
   return_std: False

# Final linear classifier: encoder_dim features -> one score per scenario
# class. The class count comes from output_neurons (instead of a hard-coded
# 18), so it is configured in a single place.
output_mlp: !new:speechbrain.nnet.linear.Linear
   input_size: !ref <encoder_dim>
   n_neurons: !ref <output_neurons>
   bias: False

# Modules handed to the SpeechBrain Brain object, presumably accessed as
# self.modules.<name> in the training script. The Brain moves them to the run
# device, which is why the frozen codec is listed here as well.
modules:
   enc: !ref <enc>
   avg_pool: !ref <avg_pool>
   output_mlp: !ref <output_mlp>
   attention_mlp: !ref <attention_mlp>
   codec: !ref <codec>
   discrete_embedding_layer: !ref <discrete_embedding_layer>

# Trainable modules grouped into one ModuleList; the checkpointer saves it as
# "model". The frozen codec and the parameter-free avg_pool are not included.
# Presumably its parameters are what the training script passes to
# model_opt_class — confirm.
model: !new:torch.nn.ModuleList
   modules:
      - !ref <enc>
      - !ref <output_mlp>
      - !ref <discrete_embedding_layer>
      - !ref <attention_mlp>

# NOTE(review): created empty (no model file is loaded in this YAML) and not
# referenced by any other entry here; likely a leftover from the seq2seq SLURP
# recipe. Confirm the training script does not use it before removing.
tokenizer: !new:sentencepiece.SentencePieceProcessor

# Factory (`!name:`) for a classification-error tracker: each call creates a
# fresh MetricStats. reduction: batch keeps one error value per example so
# MetricStats can aggregate them.
error_stats: !name:speechbrain.utils.metric_stats.MetricStats
   metric: !name:speechbrain.nnet.losses.classification_error
      reduction: batch

# Optimizer factory (`!name:`): Adam with learning rate lr. Presumably called by
# the training script with the trainable parameters.
model_opt_class: !name:torch.optim.Adam
   lr: !ref <lr>

# NewBob annealing, starting from lr: when the relative improvement of the
# tracked metric falls below improvement_threshold, the lr is multiplied by
# annealing_factor. patient: 0 anneals without waiting extra epochs.
lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
   initial_value: !ref <lr>
   improvement_threshold: 0.0025
   annealing_factor: 0.8
   patient: 0

# Checkpointer writing to save_folder.
# NOTE(review): attention_mlp and discrete_embedding_layer are already part of
# `model`, so their weights are stored twice. The codec is frozen
# (freeze: True) but is still checkpointed. Both are redundant rather than wrong.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
   checkpoints_dir: !ref <save_folder>
   recoverables:
      model: !ref <model>
      attention_mlp: !ref <attention_mlp>
      codec: !ref <codec>
      discrete_embedding_layer: !ref <discrete_embedding_layer>
      scheduler_model: !ref <lr_annealing_model>
      counter: !ref <epoch_counter>

# Log-softmax activation (apply_log: True), producing the log-probabilities
# that the NLL losses in this file take as input.
log_softmax: !new:speechbrain.nnet.activations.Softmax
   apply_log: True

# NLL loss with label smoothing 0.1.
# NOTE(review): compute_cost (plain NLL, no smoothing) is also defined; which
# of the two the training script uses is not visible here.
seq_cost: !name:speechbrain.nnet.losses.nll_loss
   label_smoothing: 0.1

# Plain-text training log written to train_log.
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
   save_file: !ref <train_log>

# Error-rate tracker factories (`!name:`). NOTE(review): word/character error
# rates are sequence metrics and may be unused for scenario classification;
# confirm in the training script.
error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats

# split_tokens: True splits tokens into characters, i.e. a character error rate.
cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
   split_tokens: True
