from yacs.config import CfgNode as CN

# Root configuration node. new_allowed=True lets yacs accept keys that are
# not declared in this file (e.g. when merging an experiment config).
_C = CN(new_allowed=True)

###############
# Transformer #
###############
# All Transformer options are declared in one ordered mapping; the dict's
# insertion order is the key order of the resulting node.
_C.Transformer = CN(
    {
        "num_encoder_layers": 24,
        "embed_dim": 1024,
        # NOTE(review): narrower than embed_dim, whereas the WavLM section
        # uses a 4x-wide FFN (4096 for 1024) -- confirm this is intended.
        "ffn_embed_dim": 512,
        "num_heads": 8,
        "activation": "gelu",
        "dropout": 0.1,
        "bias": True,
        "normalize_before": True,
        # --- positional embeddings ---
        "conv_pos": 128,
        "conv_pos_groups": 16,
    },
    new_allowed=True,
)

##########
# WavLM #
##########
# All WavLM options are declared in one ordered mapping; the dict's insertion
# order is the key order of the resulting node. new_allowed=True keeps the
# node open to keys that are not declared here.
_C.WavLM = CN(
    {
        # --- RMS-based masking ---
        "mask_depend_on_rms": True,
        "frame_length": 400,  # 16000 * 0.025 (25 ms at 16 kHz)
        "hop_length": 320,  # 16000 * 0.020 (20 ms at 16 kHz)
        "span_space": 1,
        "h_up": 1.0,
        "h_down": 0.5,
        "l_up": 0.49,
        "l_down": 0.2,
        "small_span": 8,
        "num_small_span": 20,
        "large_span": 40,
        "num_large_span": 4,
        "max_mask_percentage": 0.64,
        # --- mainstream model ---
        "encoder_layers": 24,
        "encoder_embed_dim": 1024,
        # other values: "from_cnn", "single", "multiple",
        # "from_12_transformer", "from_last"
        "distribution_prediction": "from_24_layers_mhsa",
        "ffn_embed_dim": 4096,
        "num_heads": 16,
        "dist_mlp": False,
        "activation": "gelu",
        "dropout": 0.1,
        "bias": True,
        "normalize": True,
        "normalize_before": True,
        "relative_position_embedding": True,
        "qk_norm": False,  # query/key (QK) normalization
        # --- FinetuneWrapper ---
        "projector_dim": 256,
        "num_classes": 1240,
        # 'weighted_sum' / 'last_layer' / 'elbo' / 'layer_1'
        "output_rep": "weighted_sum",
        "deep_model": "simple",
        "deep_model_kernel_size": 5,
        "deep_model_pooling": 5,
        "deep_model_padding": 2,
        "deep_model_dropout": 0.4,
        # 'chi2' (default) / 'uniform' / 'geometric' / 'geometric_reverse' /
        # 'second_half_growth'
        "prior_distribution": "chi2",
        "attention_pooling": False,
        # --- initialize from WavLM weights ---
        "init_with_wavlm": True,
        # alternatives: ['custom_average', [(0, 1), (2, 5), (6, 13), (14, 23)]],
        #               ['custom_extract', [0, 5, 11, 17]]
        "init_style": ["identity_mapping"],
        "path_to_wavlm": "/app/data/wav-lm/WavLM-Large.pt",
        # --- initialize from another pre-trained checkpoint ---
        "init_with_ckpt": False,
        "path_to_vesper": None,  # e.g. "/app/data/temp_ckpt/model_best_val.pt"
        # --- positional embedding ---
        "conv_pos": 128,
        "conv_pos_groups": 16,
        # --- bucketed relative position embedding ---
        "num_buckets": 320,
        "max_distance": 800,
        "gru_rel_pos": True,
        # --- feature encoder ---
        "extractor_mode": "layer_norm",  # 'default' / 'layer_norm'
        # Stored as a string. NOTE(review): presumably eval'd into a list of
        # (channels, kernel, stride) tuples -- confirm. If so, total stride is
        # 5 * 2**4 * 2**2 = 320, matching hop_length above.
        "conv_feature_layers": "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
        "dropout_input": 0.1,
        # --- distribution prediction ---
        "dist_att_n_num_heads": 16,
        # NOTE(review): meaning of `d` is not evident from this file -- document.
        "d": 2,
        "rnn_n_layers": 1,
        "rnn_prediction_dropout": 0.0,
        # --- chi2 prior parameters and inference / ELBO options ---
        # NOTE(review): chi2_df / chi2_nc are presumably the degrees of freedom
        # and non-centrality used when prior_distribution == 'chi2' -- confirm.
        "chi2_df": 2,
        "chi2_nc": 5,
        "layer_used_for_inference": None,
        "elbo_share_downstream_weights": True,
    },
    new_allowed=True,
)
