# #################################
# Basic training parameters for S-CNF
# #################################

# Global RNG seed. It is also interpolated into <output_folder>, so every
# seed value gets its own experiment directory.
seed: 929
# Seeds torch while the file is being loaded. Uses HyperPyYAML's own `!apply:`
# tag instead of the PyYAML-specific `!!python/object/apply:` tag, so the file
# only relies on tags the HyperPyYAML loader itself defines.
# NOTE(review): presumably kept at the top so seeding happens before any
# `!new:` object below is instantiated -- do not move it further down.
__set_seed: !apply:torch.manual_seed [!ref <seed>]
# Device string; forwarded to Transformer_model through `!ref <device>`.
device: cuda

# Root folder for the annotation JSONs and .npy files referenced further down.
scp_folder: ./data
# Experiment directory; the seed is embedded in its name (../test_<seed>_1).
output_folder: !ref ../test_<seed>_1
# Sub-folder of <output_folder>; used by `checkpointer` as checkpoints_dir.
save_folder: !ref <output_folder>/save
# File path handed to the FileTrainLogger below as `save_file`.
train_log: !ref <output_folder>/train_log.txt
# Logger object built at load time (`!new:`), writing to <train_log>.
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

# Per-split annotation files, all resolved relative to <scp_folder>.
train_annotation: !ref <scp_folder>/Train.json
valid_annotation: !ref <scp_folder>/Validation.json
test_annotation: !ref <scp_folder>/Test.json

# NumPy files, also under <scp_folder>.
# NOTE(review): judging by the names these hold labels, majority-vote labels
# and waveform features respectively -- confirm against the data-prep code.
label_path: !ref <scp_folder>/lab.npy
label_path_maj: !ref <scp_folder>/lab-maj.npy
fea_path: !ref <scp_folder>/wav.npy


# Options for the data loader (presumably forwarded as keyword arguments by
# the training script -- confirm). batch_size follows the training parameter
# of the same name.
dataloader_options:
    batch_size: !ref <batch_size>
    shuffle: False
    # `!name:` yields PaddedBatch with the keyword arguments below pre-bound.
    collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
        # PaddedBatch expects a *list* of keys here. A bare string would be
        # checked with `key in padded_keys`, i.e. as a substring match, so the
        # key is wrapped in a list to select exactly the `sig` entry.
        padded_keys: [sig]

# NOTE(review): not referenced elsewhere in this file; presumably read by the
# training script to order the data -- confirm.
sorting: descending

# Un-called reference (`!name:`) to the project-local `modules.MetricStats_Acc`;
# nothing is instantiated at load time, the consumer calls it when needed.
acc_stats: !name:modules.MetricStats_Acc

# Training parameters
# -- run control --------------------------------------------------------
ckpt_interval_minutes: 15
test_only: False
# Upper bound used by `epoch_counter` (limit).
number_of_epochs: 40

# -- optimisation -------------------------------------------------------
# Referenced by `dataloader_options.batch_size`.
batch_size: 64
# Shared by `opt_class` (lr) and `lr_annealing` (initial_value).
lr: 1.2
# Passed to Transformer_model as `dp`.
dp: 0.2
gradient_accumulation: 1

# -- sampling -----------------------------------------------------------
# NOTE(review): neither value is referenced in this file; both are
# presumably read directly by the training script -- confirm.
num_samples: 50
num_elbo: 20

# Model parameters
input_dim: 768

# Size of the output layer. Keep exactly one `output_dim` line active:
#   Emotion: 'neu':0,'hap':1,'sad':2,'ang':3,'oth':4
#   Hate:    "normal":0, "offensive":1,"hatespeech":2
output_dim: 5  # Emotion
# output_dim: 3  # Hate

# Upstream (SSL) encoder switches; consumed by `SSLModel`.
freeze_SSL: True
freeze_SSL_conv: True
SSL_hub: "microsoft/wavlm-base-plus"

# Keep exactly one `num_pretrain_layers` line active, matching the task.
num_pretrain_layers: 13  # Emotion
# num_pretrain_layers: 1  # Hate

# Downstream transformer / fully-connected sizes; consumed by
# `Transformer_model`.
nhead: 4
num_trans_encoder: 2
d_trans: 128
num_fc: 2
d_fc: 128

# NOTE(review): not referenced elsewhere in this file; presumably read by
# the training script when it builds the flow -- confirm.
flow_num_block: 3
nvp_hidden_width: 64

# Consumed by `softmax_enc` (dnn_blocks / dnn_neurons).
softmax_num_block: 1
softmax_hidden_dim: 64

# Upstream model: SpeechBrain's HuggingFace wrapper, instantiated at load time
# with the checkpoint named by <SSL_hub>.
SSLModel: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
    # Which checkpoint to use and the path given as `save_path`.
    source: !ref <SSL_hub>
    save_path: ../data/wav2vec2_checkpoint
    # Freezing switches, driven by the model parameters declared earlier.
    freeze: !ref <freeze_SSL>
    freeze_feature_extractor: !ref <freeze_SSL_conv>
    # Output handling.
    output_norm: False
    output_all_hiddens: True

# Downstream model (project-local `modules.TransformerModel`). Every argument
# is wired to a scalar declared earlier in this file.
Transformer_model: !new:modules.TransformerModel
    # Input / output sizes and number of upstream layers.
    input_dim: !ref <input_dim>
    output_dim: !ref <output_dim>
    num_pretrain_layers: !ref <num_pretrain_layers>
    # Transformer encoder.
    d_trans: !ref <d_trans>
    nhead: !ref <nhead>
    num_encoder_layers: !ref <num_trans_encoder>
    # Fully-connected part.
    num_fc: !ref <num_fc>
    d_fc: !ref <d_fc>
    # Remaining settings.
    dp: !ref <dp>
    device: !ref <device>

# Project-local `modules.linear_enc`; its input and output widths are both
# tied to <output_dim>.
softmax_enc: !new:modules.linear_enc
    input_dim: !ref <output_dim>
    dnn_blocks: !ref <softmax_num_block>
    dnn_neurons: !ref <softmax_hidden_dim>
    output_dim: !ref <output_dim>

# Epoch counter object, capped at <number_of_epochs>; it is also registered
# with the checkpointer as `counter`.
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

# Name -> object mapping of the three models defined above. The keys are the
# names the consumer will see, so they differ from the top-level YAML keys.
# NOTE(review): presumably passed to the SpeechBrain Brain class -- confirm.
modules:
    SSL_encoder: !ref <SSLModel>
    feature_extractor: !ref <Transformer_model>
    softmax_encoder: !ref <softmax_enc>

# Optimizer factory: `!name:` pre-binds the keyword arguments below and
# leaves the call itself to the consumer.
opt_class: !name:torch.optim.Adadelta
    lr: !ref <lr>
    rho: 0.95
    # Keep the "." in the mantissa: YAML 1.1 loaders such as PyYAML only
    # resolve exponent notation to a float when a dot is present.
    eps: 1.e-8

# Learning-rate scheduler object; it starts from the same <lr> that the
# optimizer receives.
lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
    improvement_threshold: 0.001
    annealing_factor: 0.8
    patient: 1

# Checkpointer object storing its files under <save_folder>.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    # Objects registered for saving/restoring. SSLModel is not listed.
    # NOTE(review): that matches freeze_SSL: True; if the SSL encoder is ever
    # unfrozen its weights would not be checkpointed -- confirm this is intended.
    recoverables:
        Transformer_model: !ref <Transformer_model>
        softmax_encoder: !ref <softmax_enc>
        counter: !ref <epoch_counter>
        scheduler: !ref <lr_annealing>