from yacs.config import CfgNode as CN

# Root configuration node; every section below is attached to it as an attribute.
# new_allowed=True is yacs's opt-in for accepting keys that were not declared
# here when another config is merged into this node (see the yacs docs).
_C = CN(new_allowed=True)

###############
# Transformer #
###############
# Transformer defaults, declared in one mapping handed to the node constructor.
# Key order matches the declaration order used previously.
_C.Transformer = CN(
    {
        "EPOCH": 120,
        "batch_size": 720,
        "lr": 0.0005,
    },
    new_allowed=True,
)

#####################
# WavLM fine-tuning #
#####################
# WavLM fine-tuning defaults, declared in one mapping handed to the node
# constructor. Keys are listed in their original declaration order.
_C.WavLM_finetune = CN(
    {
        "epoch": 15,
        "batch_size": 4,
        "accumulate_each_n_steps": 1,
        "val_batch_size": 4,
        "lr": 1e-5,
        "seed": 100,
        "clip_grad_value": 1.0,
        "kl_gamma": 0.1,
        "clip_grad": True,
        # Optimizer name; the values noted for this field are 'AdamW' / 'sgd'.
        "optimizer": "AdamW",
        "log_dir": "dummy_test",
        "weight_decay": 0.01,
        "freeze_cnn": True,
        "freeze_upstream": False,
        "wandb_project": "dummy_test",
        "device_id": "0",
        # Log the training loss to wandb every n steps ("from 0 card" in the
        # original note -- presumably the first GPU / rank 0; confirm in trainer).
        "wandb_train_step_log_interval": 1,
        # Run validation and log its loss to wandb every n epochs.
        "wandb_val_epoch_interval": 1,
        "ce_weights": [1.25, 0.85, 0.81, 1.28],
        "ds_size": None,
        "resume": None,
        "wandb_mode": "online",
    },
    new_allowed=True,
)
