# ############################################################################
# Recipe for "direct" (speech -> scenario) "Intent" classification
# using SLURP Dataset.
# 18 scenario classes are present in SLURP (calendar, email ...)
# We encode input waveforms into features using discrete tokens.
# The probing is done using an RNN layer followed by a linear classifier.
############################################################################

# Seed needs to be set at top of yaml, before objects with parameters are made
# The `!apply` entry passes <seed> to torch.manual_seed.
seed: 1986
__set_seed: !apply:torch.manual_seed [!ref <seed>]
# Every run writes under a folder named after the seed.
output_folder: !ref results/SLURP/dac/<seed>
# Passed to the checkpointer as `checkpoints_dir`.
save_folder: !ref <output_folder>/save
# Passed to the train logger as `save_file`.
train_log: !ref <output_folder>/train_log.txt

# Dataset location and prepared manifests
# SLURP is downloaded automatically into `data_folder` when it is missing.
# `data_folder` has no default and must be supplied by the user.
data_folder: !PLACEHOLDER
# data_folder_rirs: !ref <data_folder>
# Names of the training splits to use.
train_splits:
   - train_real
# CSV manifests, one per split, stored next to the other run outputs.
csv_train: !ref <output_folder>/train-type=direct.csv
csv_valid: !ref <output_folder>/devel-type=direct.csv
csv_test: !ref <output_folder>/test-type=direct.csv
# NOTE(review): presumably skips data preparation when True; confirm in the training script.
skip_prep: False

# Loss function: negative log-likelihood from speechbrain.nnet.losses.
compute_cost: !name:speechbrain.nnet.losses.nll_loss

# Training parameters
precision: fp32
number_of_epochs: 20  # limit of <epoch_counter>
batch_size: 2  # shared by the train, valid and generic dataloader options
test_batch_size: 1  # used by <test_dataloader_opts>
lr: 0.0002  # learning rate of the optimizer and initial value of the LR scheduler
# token_type: unigram # ["unigram", "bpe", "char"]
sorting: random
ckpt_interval_minutes: 5 # save checkpoint every N min

# Model parameters
output_neurons: 18 # number of SLURP scenario classes

### Config for Tokenizer
# DAC parameters
# Reference table of the available DAC variants; the lists below are
# presumably aligned by position (one column per variant) - TODO confirm.
# model_type: [16khz, 24khz, 44khz, 44khz]
# vocab_size: [1024, 1024, 1024, 1024]
# model_bitrate: [8kbps, 8kbps, 8kbps, 16kbps]
# max_num_codebooks: [12, 32, 9, 18]
# embedding_dim: [1024, 1024, 1024, 128]
# Values selected for this recipe.
model_type: 24khz
vocab_size: 1024
model_bitrate: 8kbps
num_codebooks: 2  # NOTE: must be less than or equal to the maximum number of codebooks for the given model type
sample_rate: 24000
# Feature dimension shared by the embedding layer, attention MLP, encoder and classifier input.
encoder_dim: 1024


# Per-split dataloader settings
# Training: shuffled batches, the last incomplete batch is kept.
train_dataloader_opts:
   shuffle: True
   drop_last: False
   batch_size: !ref <batch_size>
   num_workers: 2  # 2 on linux but 0 works on windows

# Validation reuses the training batch size.
valid_dataloader_opts:
   batch_size: !ref <batch_size>

# Testing has its own batch size.
test_dataloader_opts:
   batch_size: !ref <test_batch_size>

# Modules
# DAC model (see https://github.com/descriptinc/descript-audio-codec)
# Created with `load_pretrained: True` and the `latest` tag for the selected
# model type and bitrate.
codec: !new:speechbrain.lobes.models.discrete.dac.DAC
   model_type: !ref <model_type>
   model_bitrate: !ref <model_bitrate>
   load_pretrained: True
   tag: latest

# Embedding layer from the recipe-local `custom_model` module, configured with
# the number of codebooks, the vocabulary size and the embedding dimension.
discrete_embedding_layer: !new:custom_model.Discrete_EmbeddingLayer
   num_codebooks: !ref <num_codebooks>
   vocab_size: !ref <vocab_size>
   emb_dim: !ref <encoder_dim>

# Attention MLP from the recipe-local `custom_model` module; input and hidden
# sizes both equal <encoder_dim>.
# NOTE(review): presumably combines the per-codebook embeddings - confirm in custom_model.
attention_mlp: !new:custom_model.AttentionMLP
   input_dim: !ref <encoder_dim>
   hidden_dim: !ref <encoder_dim>

# Encoder: a 2-layer bidirectional LSTM followed by a linear layer that maps
# back to <encoder_dim>. Layers run in the order they are listed.
enc: !new:speechbrain.nnet.containers.Sequential
   input_shape: [null, null, !ref <encoder_dim>]
   lstm: !new:speechbrain.nnet.RNN.LSTM
      input_size: !ref <encoder_dim>
      bidirectional: True
      hidden_size: !ref <encoder_dim>
      num_layers: 2
   linear: !new:speechbrain.nnet.linear.Linear
      # Twice <encoder_dim> because the LSTM above is bidirectional.
      input_size: !ref <encoder_dim> * 2
      n_neurons: !ref <encoder_dim>

# Decoding parameters
# NOTE(review): these look like sequence-decoding (beam search) settings and
# are not referenced elsewhere in this section of the file - confirm that the
# training script for this classification recipe actually reads them.
bos_index: 0
eos_index: 0
min_decode_ratio: 0.0
max_decode_ratio: 10.0
slu_beam_size: 80
eos_threshold: 1.5
temperature: 1.25

# Generic dataloader options.
# NOTE(review): overlaps with <train_dataloader_opts>; confirm which of the two
# the training script uses.
dataloader_opts:
   batch_size: !ref <batch_size>
   shuffle: True

# Epoch counter limited to <number_of_epochs>.
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
   limit: !ref <number_of_epochs>

# Statistics pooling with the standard-deviation output disabled.
# NOTE(review): the key name suggests mean pooling over time - confirm the
# default of the remaining StatisticsPooling arguments.
avg_pool: !new:speechbrain.nnet.pooling.StatisticsPooling
   return_std: False

# Final classifier: bias-free linear layer from <encoder_dim> features to one
# output per class. The output size follows <output_neurons>, so overriding
# that key changes the classifier as well.
output_mlp: !new:speechbrain.nnet.linear.Linear
   input_size: !ref <encoder_dim>
   n_neurons: !ref <output_neurons>
   bias: False

# Named modules made available to the training script.
modules:
   enc: !ref <enc>
   avg_pool: !ref <avg_pool>
   output_mlp: !ref <output_mlp>
   attention_mlp: !ref <attention_mlp>
   codec: !ref <codec>
   discrete_embedding_layer: !ref <discrete_embedding_layer>

# ModuleList grouping the encoder, classifier, embedding layer and attention
# MLP. The codec and avg_pool are not part of this list.
model: !new:torch.nn.ModuleList
   - [!ref <enc>, !ref <output_mlp>, !ref <discrete_embedding_layer>, !ref <attention_mlp>]

# SentencePiece processor created with no arguments.
# NOTE(review): not referenced elsewhere in this section of the file - confirm
# it is needed for a classification recipe.
tokenizer: !new:sentencepiece.SentencePieceProcessor

# Metric factory: classification error computed with `reduction: batch`.
error_stats: !name:speechbrain.utils.metric_stats.MetricStats
   metric: !name:speechbrain.nnet.losses.classification_error
      reduction: batch

# Adam optimizer constructor with learning rate <lr>.
model_opt_class: !name:torch.optim.Adam
   lr: !ref <lr>

# NewBob learning-rate scheduler starting from <lr>.
lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
   initial_value: !ref <lr>
   improvement_threshold: 0.0025
   annealing_factor: 0.8
   patient: 0

# Checkpointer writing to <save_folder>.
# NOTE(review): attention_mlp and discrete_embedding_layer are also members of
# <model>, so they are registered twice; the codec is registered as well even
# though it is created with `load_pretrained: True`. Confirm this is intended
# before changing it, since existing checkpoints depend on these names.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
   checkpoints_dir: !ref <save_folder>
   recoverables:
      model: !ref <model>
      attention_mlp: !ref <attention_mlp>
      codec: !ref <codec>
      discrete_embedding_layer: !ref <discrete_embedding_layer>
      scheduler_model: !ref <lr_annealing_model>
      counter: !ref <epoch_counter>

# Softmax configured with `apply_log: True`.
log_softmax: !new:speechbrain.nnet.activations.Softmax
   apply_log: True

# Negative log-likelihood loss with label smoothing of 0.1.
seq_cost: !name:speechbrain.nnet.losses.nll_loss
   label_smoothing: 0.1

# File logger writing to <train_log>.
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
   save_file: !ref <train_log>

# Error-rate metric factories; the second one sets `split_tokens: True`.
# NOTE(review): these look like sequence-level metrics - confirm that the
# training script for this classification recipe uses them.
error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats

cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
   split_tokens: True
