# #################################
# Basic training parameters for I-CNF
# #################################

# Random seed. It is passed to torch.manual_seed on the next line and is also
# embedded in the name of `output_folder`.
seed: 929
# Calls torch.manual_seed(<seed>) while this file is being loaded; the key
# only exists to hold that call.
__set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
# Device string; passed to the downstream model as its `device` argument.
device: cuda

# Not referenced elsewhere in this file. PLACEHOLDER is presumably meant to be
# overridden (e.g. from the command line) — confirm against the training script.
data_folder: PLACEHOLDER
# Directory that the annotation JSON files and the .npy files below are
# resolved against. Note the trailing "/".
scp_folder: ../data/
# Experiment directory; the seed is part of its name (../test_<seed>_1).
output_folder: !ref ../test_<seed>_1
# Checkpoint directory handed to the checkpointer.
save_folder: !ref <output_folder>/save
# Text log written by `train_logger`.
train_log: !ref <output_folder>/train_log.txt
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

# Annotation files for the train / valid / test splits, resolved against
# <scp_folder>. Because <scp_folder> already ends in "/", the resolved paths
# contain "//" (e.g. ../data//train_mos_list.json).
# NOTE(review): a doubled slash is normally collapsed on POSIX systems —
# confirm on other platforms.
train_annotation: !ref <scp_folder>/train_mos_list.json
valid_annotation: !ref <scp_folder>/valid_mos_list.json
test_annotation: !ref <scp_folder>/test_mos_list.json

# NumPy files under <scp_folder>. Their contents are not visible from this
# file — presumably the labels and the precomputed features; confirm against
# the training script.
label_file: !ref <scp_folder>/somos-lab-full.npy
fea_path: !ref <scp_folder>/somos-wav.npy

# Data-loading options; presumably forwarded to the data loader by the
# training script (confirm there). `batch_size` is defined further down under
# the training parameters.
dataloader_options:
    batch_size: !ref <batch_size>
    # NOTE(review): shuffling is off, presumably because the data is ordered
    # according to `sorting` below — confirm against the training script.
    shuffle: False
    collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
        # Must be a list, not a bare string: PaddedBatch checks
        # `key in padded_keys`, which is a substring match against a string
        # (so keys such as "s" or "ig" would also have been padded).
        padded_keys: [sig]
# Not referenced elsewhere in this file; presumably read by the training
# script to order the training data.
sorting: descending

# Factory (`!name:` yields a callable, not an instance) for a MetricStats
# tracker that uses mse_loss as its metric.
# NOTE(review): no `reduction` argument is bound to mse_loss here. If
# per-utterance scores are expected, confirm that the training script passes
# one when appending to the tracker.
error_stats_mse: !name:speechbrain.utils.metric_stats.MetricStats
    metric: !name:speechbrain.nnet.losses.mse_loss

# Training parameters
# Keys marked "script" are not referenced elsewhere in this file and are
# presumably read directly by the training script — confirm there.
ckpt_interval_minutes: 15  # script
test_only: False  # script
number_of_epochs: 40  # limit of `epoch_counter`
batch_size: 64  # used by `dataloader_options`
lr: 0.05  # used by `opt_class` and as the initial value of `lr_annealing`
dp: 0.0  # passed to `Transformer_model` as `dp`
gradient_accumulation: 1  # script
num_samples: 50  # script

# Model parameters
# Input / output sizes passed to `Transformer_model`.
input_dim: 768
output_dim: 1

# Settings of the upstream SSL encoder (`SSLModel`).
freeze_SSL: True
freeze_SSL_conv: True
SSL_hub: "microsoft/wavlm-base-plus"
# Settings of the downstream `Transformer_model`.
num_pretrain_layers: 13
nhead: 4
num_trans_encoder: 2
d_trans: 128
num_fc: 2
d_fc: 128

# Not referenced elsewhere in this file; presumably read by the training
# script when it builds the flow part of the model — confirm there.
flow_num_block: 3
nvp_hidden_width: 16

# Upstream model: pretrained SSL encoder loaded from <SSL_hub>.
# NOTE(review): <SSL_hub> names a WavLM checkpoint while the wrapper class is
# HuggingFaceWav2Vec2 — confirm that the installed SpeechBrain version accepts
# this combination.
SSLModel: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
    source: !ref <SSL_hub>
    output_norm: False
    freeze: !ref <freeze_SSL>
    freeze_feature_extractor: !ref <freeze_SSL_conv>
    # Presumably returns the hidden states of every layer rather than only the
    # last one — confirm against the wrapper's documentation.
    output_all_hiddens: True
    # Hard-coded path; it is not derived from <scp_folder>.
    save_path: ../data/wav2vec2_checkpoint

# Downstream model: project-local `modules.TransformerModel`, instantiated at
# load time. Every argument is taken from the scalar keys defined above.
Transformer_model: !new:modules.TransformerModel
    # Sizes
    input_dim: !ref <input_dim>
    output_dim: !ref <output_dim>
    num_pretrain_layers: !ref <num_pretrain_layers>
    # Transformer encoder settings
    d_trans: !ref <d_trans>
    nhead: !ref <nhead>
    num_encoder_layers: !ref <num_trans_encoder>
    # Fully connected settings
    num_fc: !ref <num_fc>
    d_fc: !ref <d_fc>
    dp: !ref <dp>
    device: !ref <device>
# Epoch counter limited to <number_of_epochs>; it is also registered with the
# checkpointer as `counter`.
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

# Name -> module mapping; presumably handed to the training Brain class as its
# `modules` argument — confirm against the training script.
modules:
    SSL_encoder: !ref <SSLModel>
    feature_extractor: !ref <Transformer_model>

# Optimizer factory: `!name:` binds the keyword arguments below to
# torch.optim.Adadelta without constructing it; the parameters to optimise are
# presumably supplied later by the training code.
opt_class: !name:torch.optim.Adadelta
    lr: !ref <lr>
    rho: 0.95
    eps: 1.e-8

# Learning-rate scheduler, starting from <lr>. It is registered with the
# checkpointer as `scheduler`.
lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
    improvement_threshold: 0.001
    annealing_factor: 0.8
    # NOTE(review): presumably the number of non-improving steps tolerated
    # before the rate is annealed — confirm against the NewBobScheduler docs.
    patient: 1

# Checkpointer writing to <save_folder>. It tracks the downstream model, the
# epoch counter and the LR scheduler.
# NOTE(review): `SSLModel` is not listed under `recoverables`. That matches
# `freeze_SSL: True`; if the SSL encoder is ever fine-tuned, confirm whether
# it should be added here.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        Transformer_model: !ref <Transformer_model>
        counter: !ref <epoch_counter>
        scheduler: !ref <lr_annealing>