from yacs.config import CfgNode as CN

# Root of the default configuration tree. `new_allowed=True` is the yacs flag
# that lets a merged-in config add keys that are not declared in this file.
_C = CN(new_allowed=True)

###############
# Transformer #
###############
# Defaults for the `Transformer` section, declared as one mapping and handed
# to the node constructor. The resulting keys and values are the same as
# assigning them one attribute at a time.
_C.Transformer = CN(
    {
        "num_encoder_layers": 24,
        "embed_dim": 1024,
        "ffn_embed_dim": 512,
        "num_heads": 8,
        "activation": "gelu",
        "dropout": 0.1,
        "bias": True,
        "normalize_before": True,
        # positional embeddings
        "conv_pos": 128,
        "conv_pos_groups": 16,
    },
    new_allowed=True,
)

#########
# WavLM #
#########
# Defaults for the `WavLM` section, declared as one mapping and handed to the
# node constructor. The resulting keys and values are the same as assigning
# them one attribute at a time; the sub-headings below group related options.
_C.WavLM = CN(
    {
        # mainstream model
        "encoder_layers": 24,
        "encoder_embed_dim": 1024,
        # Options noted by the author: "from_24_layers_rnn", "from_cnn",
        # "single", "multiple", "from_12_transformer", "from_last",
        # "from_24_layers_mhsa"
        "distribution_prediction": "from_24_layers_mhsa",
        "ffn_embed_dim": 4096,
        "num_heads": 16,
        "activation": "gelu",
        "dropout": 0.1,
        "bias": True,
        "normalize": True,
        "normalize_before": True,
        "relative_position_embedding": True,
        "qk_norm": False,  # query/key (QK) normalization
        # positional embedding
        "conv_pos": 128,
        "conv_pos_groups": 16,
        # bucket relative position embedding
        "num_buckets": 320,
        "max_distance": 800,
        "gru_rel_pos": True,
        # FinetuneWrapper
        "projector_dim": 256,
        "num_classes": 4,
        # 'weighted_sum' / 'last_layer' / 'elbo' / 'layer_1' /
        # 'weighted_hiddens' / 'mhfa'
        "output_rep": "elbo",
        "deep_model": "simple",
        "deep_model_kernel_size": 5,
        "deep_model_pooling": 5,
        "deep_model_padding": 2,
        "deep_model_dropout": 0.4,
        # Alternatives noted by the author: "geometric", "eps_degenerated",
        # 'asr_from_wer' / 'asr_from_weights'
        "prior_distribution": "weighted_sum_inference",
        "p_for_geometric_pmf": 0.5,
        "upsample_rate": 320,
        "module": "LSTM",
        # ASR
        # NOTE(review): the name reads like a count but the default is a
        # bool -- confirm the intended type with the code that consumes it.
        "n_asr_models": True,
        "layer_position_encoding": False,
        "asr_model": "linear",
        "distribution_prediction_architecture": "sequential",
        "elbo_use_n_last_outputs": 24,
        # initialize with WavLM
        "init_with_wavlm": True,
        # NOTE(review): hard-coded absolute path; presumably specific to one
        # deployment environment -- override where the checkpoint lives
        # elsewhere.
        "path_to_wavlm": "/app/data/wav-lm/WavLM-Large.pt",
        # feature encoder
        "extractor_mode": "layer_norm",  # 'default' / 'layer_norm'
        # Stored as a string holding a Python list expression; presumably
        # evaluated by the consumer -- verify.
        "conv_feature_layers": "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
        "dropout_input": 0.2,
        "bidirection": True,
        "rnn_dim": [1024, 1024],
        "rnn_dropout": [0.2, 0.2],
        "rnn_layer_norm": [False, False],
        # Linear projection + Tanh after each rnn layer
        "rnn_proj": [False, False],
        "rnn_sample_rate": [1, 1],
        "rnn_total_rate": -1,
        "sample_style": "concat",
        # RNN distribution_prediction
        "rnn_hid_dim": 256,
        "rnn_n_layers": 1,
        "rnn_prediction_dropout": 0.0,
        # 2 if bidirectional else 1. This is a literal, not derived from any
        # other option, so it has to be kept in sync by hand.
        "d": 2,
        # Attention Distribution Prediction
        "dist_att_n_num_heads": 16,
        # MHFA
        "mhfa_compression_dim": 128,
        "mhfa_head_nb": 4,
        # KenLM decoder args
        "nbest": 1,
        "unit_lm": False,
        "criterion": "ctc",
        "beam": 5,
        "beam_threshold": 20,
        "word_score": -1,
        "sil_weight": 0,
        "chi2_df": 2,
        "chi2_nc": 5,
    },
    new_allowed=True,
)
