notes: "basic model = offline + 8k&16k" # Describe this experiment's changes here (please write details!)
# NOTE(review): this note says "offline + 8k&16k", but config.model.online is True and dataset_phase is "to48k" — confirm the note is still current.
# ------------------------------------------------------------------------------------------------------------------------------ #
config:
    # IMPORTANT: train_phase and dataset_phase must be kept consistent with each other.
    # NOTE(review): presumably the active phase key is "<train_phase>_<dataset_phase>"
    # (e.g. "adversarial_to48k") and must be one of train_phase_list — confirm in the loader.
    train_phase: &var_tr_phase "adversarial" # one of: pretrain / adversarial
    dataset_phase: "to48k"
    train_phase_list:
        - 'pretrain_to48k'
        - 'adversarial_to48k'
    # -------------------------------------- Dataset --------------------------------------- #
    dataset:
        to48k:
            max_len: 3 #! audio length in seconds (originally 4 s; temporarily reduced due to memory)
            sample_rate_src: 48000
            sample_rate_in: 16000
            scp_dir: "data/scp/scp_VCTK"
            train:
                spk: "tr_s.scp"
                noise: "tr_n.scp" ## the noise list is not paired with the speaker list
            valid:
                spk: "cv_s.scp"
                # NOTE(review): validation reuses the training noise list (tr_n.scp) — confirm this is intended.
                noise: "tr_n.scp" ## the noise list is not paired with the speaker list
            rir: "DNS_48K"  # RIR alias - actual path defined in .env as RIR_DNS_48K
        # NOTE: "sythesis" is misspelled, but the key is kept as-is because the consuming code
        # presumably looks it up by this exact name — rename both sides together if ever fixed.
        sythesis_config:
            multi_spk_prob: 0.2
            rir:
                prob: 0.5
                rir_sidelobe: 1 #! sidelobe length in milliseconds for target RIR synthesis, e.g. 1 ms = 16 samples at 16 kHz
            noise:
                SNR_range: [0, 20]
                c_SNR_range: [0, 20]
                c_beta_range: [0.5, 1.5]
            BPF:
                prob: 0.5
                fir_filter_beta: [0.25, 1.0]
                low_cutoff_freq_range: [1000, 3000]
            clipping:
                prob: 0.5
                clipping_level_range: [-15, 0]
            level:
                target_dB_FS: [-35, -15]
    dataset_test:
        testset_key: "VCTK_SR" #! one of: VCTK_DEMAND / VCTK_SR / UNIVERSE / VoxCeleb (must name a sub-config below)
        tensorboard_logging: true
        input_eval: true #! if false, only enhanced audio is evaluated
        output_eval: true #! if false, only noisy audio is evaluated
        VCTK_DEMAND: # simple denoising subtask
            metrics: ['pesq', 'stoi', 'lsd', 'sdr', 'mcd',
                      'wvmos', 'utmos', 'dnsmos',
                      'bleu', 'bertscore', 'tokendist']
            sample_rate_src: 16000
            sample_rate_in: 16000
            clean_dir: "${VCTK_DEMAND_DB_ROOT}/clean_testset_wav"
            noisy_dir: "${VCTK_DEMAND_DB_ROOT}/noisy_testset_wav"
            random_sample_idx: [10, 20, 30, 40, 50]
        VCTK_SR: # simple super-resolution subtask
            metrics: ['lsd', 'mcd', 'utmos']
            # NOTE(review): audio comes from "wav48_silence_trimmed" (presumably 48 kHz files) while
            # sample_rate_src is 44100 — confirm 44100 is the intended rate.
            sample_rate_src: 44100
            sample_rate_in: 8000
            # clean_dir and noisy_dir are the same directory; the degraded input is presumably
            # produced on the fly at sample_rate_in — confirm in the test loader.
            clean_dir: "${VCTK_CORPUS_DB_ROOT}/wav48_silence_trimmed/test"
            noisy_dir: "${VCTK_CORPUS_DB_ROOT}/wav48_silence_trimmed/test"
            random_sample_idx: [10, 20, 30, 40, 50]
        UNIVERSE: # general speech restoration
            metrics: ['pesq', 'stoi', 'lsd', 'sdr', 'mcd',
                      'wvmos', 'utmos', 'dnsmos',
                      'bleu', 'bertscore', 'tokendist']
            sample_rate_src: 16000
            sample_rate_in: 16000
            clean_dir: "${UNIVERSE_DB_ROOT}/target"
            noisy_dir: "${UNIVERSE_DB_ROOT}/input"
            random_sample_idx: [10, 20, 30, 40, 50]
        VoxCeleb: # real-recording test set; no clean reference (clean_dir is null)
            metrics: ['wvmos', 'utmos', 'dnsmos']
            sample_rate_src: 48000
            # Writing `none` here (lowercase, per the original note — not `None`) is meant to let the
            # model process various input sampling rates.
            # NOTE(review): YAML loads both `none` and `None` as plain strings (only `null`/`~` give null);
            # presumably the loader compares against the string "none" — confirm.
            sample_rate_in: 16000
            clean_dir: null
            noisy_dir: "${VOXCELEB_SAMPLE_DB_ROOT}"
            random_sample_idx: [10, 20, 30, 40, 50]
    # ------------------------------------------------------------ #
    dataloader:
        # NOTE(review): these keys match torch.utils.data.DataLoader arguments and are presumably forwarded to it — confirm.
        batch_size: 2
        pin_memory: false
        num_workers: 0 # if forwarded to a torch DataLoader, 0 means data is loaded in the main process
        drop_last: false
    # ------------------------------------------------------------ #
    # Sampling rates the model can process directly, without resampling (quoted, so loaded as strings).
    fs_list: &var_fs_list
        - '8000'
        - '16000'
        - '22050'
        - '24000'
        - '32000'
        - '44100'
        - '48000'
    stft: #! base length is based on 16 kHz; sizes are in milliseconds, so the sample count depends on the rate
        frame_length: 40 #! ms, not samples (1920 samples at 48 kHz, 320 at 8 kHz)
        frame_shift: 20 #! ms, not samples
    # ------------------------------------------- Model ------------------------------------------- #
    model:
        # Global streaming switch, aliased into every sub-module below.
        online: &var_online true #! new feature: mamba-based time module when online else TF-Locoformer-based one
        input_embedding:
            online: *var_online
            d_model: &var_model_channels_1st 128
            d_freq: &num_frequency_bin_max 961 #! win/2+1 at 48kHz = 1920/2+1 (40 ms frame)
            freq_pe: true
        freq_linear:
            seq_len: *num_frequency_bin_max
            proj_len: 512
            n_heads: &var_n_head 4
            kv_shared: true
        # -------------------------- Encoder ------------------------- #
        encoder_stage:
            block_type: 'Encoder'
            RoPE:
                d_model: *var_model_channels_1st
                n_head: *var_n_head
                theta: &var_rope_theta 10000
            TF_block_Stage:
                online: *var_online
                # Both variants are configured; presumably `online` selects which one is built.
                time_module:
                    offline:
                        d_model: *var_model_channels_1st
                        d_hidden: &var_model_hidden_1st 384
                        n_head: *var_n_head
                        kernel_size: &var_kernel_size 7
                        dropout_rate: 0.00
                    online:
                        d_model: *var_model_channels_1st
                        d_state: &var_d_state 16
                        d_conv: &var_kernel_mamba 3
                        expand: &var_expand_mamba 4
                        dropout_rate: 0.00
                freq_module:
                    d_model: *var_model_channels_1st
                    d_hidden: *var_model_hidden_1st
                    n_head: *var_n_head
                    kernel_size: *var_kernel_size
                    dropout_rate: 0.00
            num_repeat: 6
        # ------------ Projection & Extension Query Padding ------------ #
        # Maps encoder channels (128) to decoder channels (64) and spans 161..961 frequency bins.
        freq_upsampler:
            d_model: *var_model_channels_1st
            d_model_out: &var_model_channels_2nd 64
            d_freq_min: 161 #! win/2+1 at 8kHz = 320/2+1 (40 ms frame)
            d_freq_max: *num_frequency_bin_max
        # -------------------------- Decoder ------------------------- #
        decoder_stage:
            block_type: 'Decoder'
            RoPE:
                d_model: *var_model_channels_2nd
                n_head: *var_n_head
                theta: *var_rope_theta
            TF_block_Stage:
                online: *var_online
                time_module:
                    offline:
                        d_model: *var_model_channels_2nd
                        d_hidden: &var_model_hidden_2nd 192
                        n_head: *var_n_head
                        kernel_size: *var_kernel_size
                        dropout_rate: 0.00
                    online:
                        d_model: *var_model_channels_2nd
                        d_state: *var_d_state
                        d_conv: *var_kernel_mamba
                        expand: *var_expand_mamba
                        dropout_rate: 0.00
                freq_module:
                    d_model: *var_model_channels_2nd
                    # Key/value width equals the encoder width (128); presumably cross-attends to encoder features — confirm.
                    d_model_kv: *var_model_channels_1st
                    d_hidden: *var_model_hidden_2nd
                    n_head: *var_n_head
                    kernel_size: *var_kernel_size
                    dropout_rate: 0.00
            num_repeat: 3
        output_spec:
            online: *var_online
            d_model: *var_model_channels_2nd
    # -------------------------------------- Training ------------------------------------------ #
    engine:
        # Input-degradation effects and their values in [0, 1]; presumably per-sample
        # application probabilities (inferred from the key name) — confirm in the trainer.
        #! ------------ set to 1.0 when 8k only input is assumed ---------- !
        # NOTE(review): the note above presumably refers to downsample_8k — confirm.
        prob_effect:
            downsample_8k: 0.25
            codec: 0.3
            crystalizer: 0.15
            flanger: 0.05
            crusher: 0.1
        # -------------------------------- #
        # Per-split epoch size: when `subset` is true, an epoch uses `num_per_epoch` samples
        # (presumably drawn from the full list — confirm how the subset is chosen).
        subset:
            train:
                subset: true
                num_per_epoch: 20000 # number of data samples for one epoch
            valid:
                subset: true
                num_per_epoch: 2000 # number of data samples for one epoch
        # -------------------------------- #
        # Phase 1 (pretraining): enhancement + SSL representation losses; no gan/fm/pesq weights.
        pretrain_to48k:
            downsample_src:
                prob: 0.6
                fs_list_src:
                    - 16000
                    - 24000
                    - 44100
            loss_enhance:
                tau: 1.0e-4
                window_size: &var_win_48k
                    - 1920 # 40 ms frame at 48 kHz
            loss_time:
                beta: 1.0e-3
            loss_rep:
                # Candidate SSL models:
                #   "facebook/wav2vec2-large-xlsr-53"
                #   "facebook/wav2vec2-xls-r-300m"
                #   "microsoft/wavlm-large"
                model_key: &var_SSL_model_key "microsoft/wavlm-large"
                resampler: # resample 48 kHz audio to 16 kHz for the SSL model
                    orig_freq: 48000
                    new_freq: 16000
            loss_weight:
                se: 1.0
                time: 0.0 # weight 0 -> time-domain term contributes nothing
                ssl: 1.0e+2
            RandSpecMasking:
                t_len: [0, 10]
                f_len: [0, 10]
                t_num: [0, 2]
                f_num: [0, 3]
        # Phase 2 (adversarial): same losses as pretrain_to48k plus gan / fm / pesq weights.
        adversarial_to48k:
            downsample_src:
                prob: 0.6
                fs_list_src: [16000, 24000, 44100]
            loss_enhance:
                tau: 1.0e-4
                window_size: *var_win_48k
            loss_time:
                beta: 1.0e-3
            loss_rep:
                model_key: *var_SSL_model_key
                resampler:
                    orig_freq: 48000
                    new_freq: 16000
            loss_weight:
                se: 1.0
                time: 0.0
                ssl: 1.0e+2
                gan: 1.0e-3
                fm: 0.1
                pesq: 1.0e-4
            RandSpecMasking: # lighter than pretrain_to48k: t_num / f_num capped at 1
                t_len: [0, 10]
                f_len: [0, 10]
                t_num: [0, 1]
                f_num: [0, 1]
            msstftd:
                filters: 32
                n_ffts_ms: [20, 40, 60, 80, 100] # window lengths in ms, not samples
                fs_list: *var_fs_list # the top-level list of supported sampling rates (loaded as strings)
        # Fixed audio files used as validation samples (presumably processed and logged each validation run).
        sample_validation:
            - 'data/valid_sample/UNIVERSE_sample/0015.wav'
            - 'data/valid_sample/UNIVERSE_sample/0034.wav'
            - 'data/valid_sample/UNIVERSE_sample/0061.wav'
            - 'data/valid_sample/UNIVERSE_sample/0090.wav'
            - 'data/valid_sample/UNIVERSE_sample/0096.wav'
            - 'data/valid_sample/UNIVERSE_sample/0098.wav'
            - 'data/valid_sample/Real_sample/0.wav'
            - 'data/valid_sample/Real_sample/1.wav'
            - 'data/valid_sample/Real_sample/2.wav'
            - 'data/valid_sample/VoxCeleb_sample/f77-id10281-ni6gO5jDLJE-00010.wav'
            - 'data/valid_sample/VoxCeleb_sample/f83-id10282-hgB5ziAudzU-00001.wav'
            - 'data/valid_sample/VoxCeleb_sample/m87-id10271-PfcJLmkhGbk-00007.wav'
            - 'data/valid_sample/VoxCeleb_sample/m89-id10304-jUSC4i_eGHs-00002.wav'
            - 'data/valid_sample/VoxCeleb_sample/m99-id10297-FvbLoirHpx0-00006.wav'
        # -------------------------------------------------------------------------------------------------- #
        # Main-model optimizer (optimizer_D below is presumably the discriminator's).
        optimizer:
            name: "AdamW" ### Choose a torch.optim's class(=attribute) e.g. ["Adam", "AdamW", "SGD", ...]
            # Hyper-parameters keyed by optimizer name; only "AdamW" and "Adam" have entries here.
            AdamW:
                lr: 2.0e-4
                betas: [0.9, 0.995]
                weight_decay: 1.0e-2
            Adam:
                lr: 1.0e-3
                weight_decay: 0.0
        # -------------------------------- #
        # Discriminator optimizer (presumably used only in the adversarial phase).
        optimizer_D:
            name: "AdamW" ### Choose a torch.optim's class(=attribute) e.g. ["Adam", "AdamW", "SGD", ...]
            # Hyper-parameters keyed by optimizer name; only "AdamW" and "Adam" have entries here.
            AdamW:
                lr: 2.0e-4
                betas: [0.8, 0.999]
                weight_decay: 1.0e-2
            Adam:
                lr: 1.0e-3
                weight_decay: 0.0
        # -------------------------------- #
        scheduler:
            name: "StepLR" # one of the sub-configs below: StepLR / ReduceLROnPlateau / WarmupConstantSchedule
            WarmupConstantSchedule:
                warmup_steps: 10000
            ReduceLROnPlateau:
                mode: "min"
                min_lr: 1.0e-10
                factor: 0.5
                patience: 2
            # Assuming torch.optim.lr_scheduler.StepLR: LR is multiplied by gamma every step_size scheduler steps.
            StepLR:
                step_size: 1
                gamma: 0.9
        # -------------------------------------------------------------------------------------------------- #
        # Inference-time settings.
        # NOTE(review): the meaning of alpha and max_iter is not visible in this file — document it where they are consumed.
        inference:
            alpha: 0.2
            max_iter: 3
        # -------------------------------- #
        # Number of epochs per training phase (keys match train_phase_list).
        max_epoch:
            pretrain_to48k: 20
            adversarial_to48k: 20
        gpuid: "1" # quoted so it loads as a string
        clip_norm: 10 # presumably the max gradient norm for clipping — confirm in the trainer
        # Per-phase epoch at which LR scheduling starts (presumably — confirm in the trainer).
        start_scheduling:
            pretrain_to48k: 10
            adversarial_to48k: 1


