# ############################################################################
# Recipe for "direct" (speech -> scenario)
# "Intent" classification using SLURP Dataset.
# SLURP defines 18 scenario classes (calendar, email, ...).
# Input waveforms are encoded into features using an SSL encoder.
# The probing is done using time-pooling followed by a linear classifier.
############################################################################

# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 1986
# Calls torch.manual_seed(<seed>) while this file is being loaded.
__set_seed: !apply:torch.manual_seed [!ref <seed>]
# Every output of a run lives under a seed-specific folder.
output_folder: !ref results/SLURP/linear/discrete_ssl/<seed>
# Holds the checkpoints plus the SSL and k-means caches defined further down.
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

# Dataset location and manifests
# SLURP is downloaded automatically into this folder. The value is a
# placeholder, so it has to be provided when launching the recipe.
data_folder: !PLACEHOLDER
# data_folder_rirs: !ref <data_folder>
train_splits:
   - "train_real"
skip_prep: False
# CSV manifest paths, kept under the run's output folder.
csv_train: !ref <output_folder>/train-type=direct.csv
csv_valid: !ref <output_folder>/devel-type=direct.csv
csv_test: !ref <output_folder>/test-type=direct.csv

# Classification loss: `!name:` yields a reference to SpeechBrain's
# negative log-likelihood loss function instead of calling it here.
compute_cost: !name:speechbrain.nnet.losses.nll_loss
# Hub URL of the encoder model; change it to benchmark different models

### Configuration for the discrete SSL model
# Example (ssl_model_type -> ssl_hub) pairs:
#   hubert   -> facebook/hubert-large-ll60k
#   wavlm    -> microsoft/wavlm-large
#   wav2vec2 -> facebook/wav2vec2-large
ssl_model_type: hubert # hubert, wavlm or wav2vec2
ssl_hub: facebook/hubert-large-ll60k
# Local cache folder handed to the SSL model as `save_path`.
ssl_folder: !ref <save_folder>/ssl_checkpoint
# K-means settings handed to the DiscreteSSL codec (repo id, local cache
# folder and the name of the dataset the quantizers refer to).
kmeans_repo_id: speechbrain/SSL_Quantization
kmeans_cache_dir: !ref <save_folder>/kmeans_checkpoint
kmeans_dataset: LibriSpeech-100-360-500
# Passed as `freeze` / `freeze_feature_extractor` to the SSL model.
freeze_ssl: True
freeze_feature_extractor: True
# Number of k-means clusters; also used as the embedding vocabulary size.
num_clusters: 1000


### Config for Tokenizer
# Layer numbers must be among the layers supported by the discrete SSL model
# (a k-means model has to be available for each listed layer).
# Example with four layers:
# ssl_layer_num: [3, 7, 12, 23]
# deduplicate: [False, False, False, False]
# bpe_tokenizer_path: [null , null,  null, null]
ssl_layer_num: [1, 3, 7, 12, 18, 23]
# NOTE(review): presumably has to equal the number of entries in
# ssl_layer_num (6 here) — confirm against custom_model.
num_codebooks: 6
# The two lists below carry one entry per layer listed in ssl_layer_num.
deduplicate: [False, False, False, False, False, False]
bpe_tokenizer_path: [null, null, null, null, null, null]
sample_rate: 16000

# Feature parameters

# Shared width: embedding dimension of the discrete tokens, input/hidden
# dimension of the attention MLP and input size of the output layer.
encoder_dim: 1024

# Training parameters
precision: fp32
number_of_epochs: 20
batch_size: 2
test_batch_size: 1
lr: 0.0002
# token_type: unigram # ["unigram", "bpe", "char"]
sorting: random
ckpt_interval_minutes: 15 # save checkpoint every N min

# Model parameters
# NOTE(review): emb_size and dec_neurons are not referenced anywhere else in
# this file — confirm the training script still reads them.
emb_size: 128
dec_neurons: 512
output_neurons: 18 # number of scenario classes (18 in SLURP)

# Per-split dataloader options (train / validation / test)
train_dataloader_opts:
   shuffle: True
   drop_last: False
   batch_size: !ref <batch_size>
   num_workers: 2  # 2 on Linux; 0 works on Windows

valid_dataloader_opts:
   batch_size: !ref <batch_size>

test_dataloader_opts:
   batch_size: !ref <test_batch_size>

# Decoding parameters
# NOTE(review): none of these keys is referenced elsewhere in this file and
# they look like sequence-decoding settings — confirm the training script
# still reads them.
bos_index: 0
eos_index: 0
min_decode_ratio: 0.0
max_decode_ratio: 10.0
slu_beam_size: 80
eos_threshold: 1.5
temperature: 1.25

# Generic dataloader options: shuffled batches of <batch_size> items
dataloader_opts:
   shuffle: True
   batch_size: !ref <batch_size>

# Epoch counter capped at <number_of_epochs>
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
   limit: !ref <number_of_epochs>

# Modules
# Tokenizer settings gathered in one mapping: the SSL layers to quantize and,
# per layer, the deduplication flag and the optional BPE tokenizer path.
# NOTE(review): presumably passed to the codec when tokenizing — confirm in
# the training script.
tokenizer_config:
   SSL_layers: !ref <ssl_layer_num>
   deduplicates: !ref <deduplicate>
   bpe_tokenizers: !ref <bpe_tokenizer_path>

# SSL encoder picked by <ssl_model_type>. Every candidate receives the same
# hub source, cache folder and freezing flags, and returns all hidden layers.
ssl_model: !apply:speechbrain.utils.hparams.choice
   value: !ref <ssl_model_type>
   choices:
      wavlm: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM
         source: !ref <ssl_hub>
         save_path: !ref <ssl_folder>
         freeze: !ref <freeze_ssl>
         freeze_feature_extractor: !ref <freeze_feature_extractor>
         output_norm: False
         output_all_hiddens: True
      hubert: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT
         source: !ref <ssl_hub>
         save_path: !ref <ssl_folder>
         freeze: !ref <freeze_ssl>
         freeze_feature_extractor: !ref <freeze_feature_extractor>
         output_norm: False
         output_all_hiddens: True
      wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
         source: !ref <ssl_hub>
         save_path: !ref <ssl_folder>
         freeze: !ref <freeze_ssl>
         freeze_feature_extractor: !ref <freeze_feature_extractor>
         output_norm: False
         output_all_hiddens: True

# Discrete tokenizer built from the selected SSL encoder and the k-means
# settings defined above
codec: !new:speechbrain.lobes.models.huggingface_transformers.discrete_ssl.DiscreteSSL
   ssl_model: !ref <ssl_model>
   num_clusters: !ref <num_clusters>
   kmeans_repo_id: !ref <kmeans_repo_id>
   kmeans_dataset: !ref <kmeans_dataset>
   save_path: !ref <kmeans_cache_dir>

# Embedding layer for the discrete tokens: <num_codebooks> codebooks with
# <num_clusters> entries each, embedded into <encoder_dim> dimensions
discrete_embedding_layer: !new:custom_model.Discrete_EmbeddingLayer
   emb_dim: !ref <encoder_dim>
   vocab_size: !ref <num_clusters>
   num_codebooks: !ref <num_codebooks>

# NOTE(review): presumably combines the per-codebook embeddings — confirm
# in custom_model
attention_mlp: !new:custom_model.AttentionMLP
   hidden_dim: !ref <encoder_dim>
   input_dim: !ref <encoder_dim>

# Statistics pooling with the standard-deviation output switched off
avg_pool: !new:speechbrain.nnet.pooling.StatisticsPooling
   return_std: False

# Final bias-free linear classifier from <encoder_dim> features to the
# scenario classes. The output size follows <output_neurons> (previously
# hard-coded to 18) so that overriding the number of classes in one place
# keeps this layer in sync.
output_mlp: !new:speechbrain.nnet.linear.Linear
   input_size: !ref <encoder_dim>
   n_neurons: !ref <output_neurons>
   bias: False

# Modules grouped in one list, registered with the checkpointer as `model`.
# Keep the order stable: ModuleList names its members by position, so
# reordering would change the parameter names stored in checkpoints.
model: !new:torch.nn.ModuleList
   - [!ref <output_mlp>, !ref <discrete_embedding_layer>, !ref <attention_mlp>]

# Named modules made available to the training script.
modules:
   avg_pool: !ref <avg_pool>
   output_mlp: !ref <output_mlp>
   attention_mlp: !ref <attention_mlp>
   codec: !ref <codec>
   discrete_embedding_layer: !ref <discrete_embedding_layer>

# NOTE(review): a SentencePiece processor is created here without any model
# file, and nothing else in this file references it — confirm it is needed.
tokenizer: !new:sentencepiece.SentencePieceProcessor


# Factory for a metric tracker built on the classification-error metric
error_stats: !name:speechbrain.utils.metric_stats.MetricStats
   metric: !name:speechbrain.nnet.losses.classification_error
      reduction: batch

# Adam optimizer factory using <lr> as learning rate
model_opt_class: !name:torch.optim.Adam
   lr: !ref <lr>

# NewBob learning-rate scheduler, starting from <lr>
lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
   initial_value: !ref <lr>
   annealing_factor: 0.8
   improvement_threshold: 0.0025
   patient: 0


# Checkpointer writing to <save_folder>; `recoverables` lists the objects
# that are saved and restored, keyed by name.
# NOTE(review): attention_mlp and discrete_embedding_layer are already
# members of <model>, so they appear to be registered twice, and codec is
# registered although freeze_ssl defaults to True — confirm both are intended.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
   checkpoints_dir: !ref <save_folder>
   recoverables:
      model: !ref <model>
      attention_mlp: !ref <attention_mlp>
      codec: !ref <codec>
      discrete_embedding_layer: !ref <discrete_embedding_layer>
      scheduler_model: !ref <lr_annealing_model>
      counter: !ref <epoch_counter>

# Softmax activation with `apply_log` enabled.
log_softmax: !new:speechbrain.nnet.activations.Softmax
   apply_log: True

# Reference to the NLL loss with label smoothing of 0.1 pre-set.
seq_cost: !name:speechbrain.nnet.losses.nll_loss
   label_smoothing: 0.1

# Logger writing training statistics to <train_log>.
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
   save_file: !ref <train_log>

# Error-rate statistics factories; the second one sets `split_tokens`.
# NOTE(review): seq_cost and the two error-rate factories look like
# sequence-level settings — confirm the training script uses them.
error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats

cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
   split_tokens: True
