# Environment Settings
# These settings specify the hardware and distributed setup for the model training.
# Adjust `num_gpus` and `dist_cfg` according to your distributed training environment.
env_setting:
  num_gpus: 1  # Number of GPUs to use. CPU-only mode is not currently supported.
  num_workers: 8  # Number of worker threads for data loading.
  seed: 1234  # Seed for random number generators to ensure reproducibility.
  # NOTE(review): presumably the number of steps between progress prints to
  # stdout - confirm against the training script.
  stdout_interval: 10
  checkpoint_interval: 1000  # Save the model to a checkpoint every N steps.
  retain_one_checkpoint: true  # Retain only the most recent checkpoint.
  # NOTE(review): presumably the number of steps between validation runs and
  # between summary/log writes, respectively - confirm against the training script.
  validation_interval: 1000
  summary_interval: 100
  dist_cfg:
    dist_backend: nccl  # Distributed training backend, 'nccl' for NVIDIA GPUs.
    dist_url: tcp://localhost:19477  # URL for initializing distributed training.
    world_size: 1  # Total number of processes in the distributed training.

# Data Path Configuration
data_cfg:
  # One clean/noisy pair of JSON files per split (train, valid, test).
  # NOTE(review): presumably each JSON file lists the audio files of that
  # split - confirm the expected schema against the data loader.
  train_clean_json: data/train_clean.json
  train_noisy_json: data/train_noisy.json
  valid_clean_json: data/valid_clean.json
  valid_noisy_json: data/valid_noisy.json
  test_clean_json: data/test_clean.json
  test_noisy_json: data/test_noisy.json

# Training Configuration
# This section details parameters that directly influence the training process,
# including batch sizes, learning rates, and optimizer specifics.
training_cfg:
  training_epochs: 350  # Number of training epochs.
  batch_size: 4  # Training batch size.
  learning_rate: 0.0005  # Initial learning rate.
  adam_b1: 0.8  # Beta1 hyperparameter for the AdamW optimizer.
  adam_b2: 0.99  # Beta2 hyperparameter for the AdamW optimizer.
  lr_decay: 0.99  # Learning rate decay per epoch.
  # Audio segment size used during training, dependent on sampling rate.
  # If measured in samples, 32000 is 2 s at stft_cfg.sampling_rate = 16000 Hz.
  segment_size: 32000
  # NOTE(review): presumably the weight of each loss term in the total
  # training loss - confirm against the training script.
  loss:
    metric: 0.05
    magnitude: 0.9
    phase: 0.3
    complex: 0.1
    time: 0.2
    # The key spelling ("consistancy") is kept as-is: renaming it here would
    # require the same change wherever the key is read.
    consistancy: 0.1
  use_PCS400: false  # Whether to use PCS.
  data_normalization: true  # Whether to normalize the input data.

# STFT Configuration
# Configuration for Short-Time Fourier Transform (STFT), crucial for audio processing models.
stft_cfg:
  sampling_rate: 16000  # Audio sampling rate in Hz.
  n_fft: 400  # FFT size; n_fft // 2 + 1 = 201 one-sided frequency bins.
  hop_size: 100  # Samples between successive frames (6.25 ms at 16000 Hz).
  win_size: 400  # Window size used in FFT (25 ms at 16000 Hz if in samples).

# Model Configuration
# Defines the architecture specifics of the model, including layer configurations and feature compression.
model_cfg:
  model_type: SEMambaCoDe2dReAt  # Model type
  mamba_version: 2  # Version of the Mamba model.
  # Channels in dense layers. For Mamba2, hid_feature * expand / headdim must
  # be a multiple of 8 (default headdim = 64); here 128 * 4 / 64 = 8. See
  # https://github.com/state-spaces/mamba/issues/351#issuecomment-2169196817
  hid_feature: 128
  compress_factor: 1  # Compression factor applied to extracted features.
  num_tfmamba: 8  # Number of Time-Frequency Mamba (TFMamba) blocks in the model.
  d_state: 16  # Dimensionality of the state vector in Mamba blocks.
  d_conv: 4  # Local convolution width within Mamba blocks.
  expand: 4  # Expansion factor for the layers within the Mamba blocks.
  norm_epsilon: 0.00001  # Epsilon for numerical stability in normalization layers within the Mamba blocks.
  beta: 2.0  # Hyperparameter for the Learnable Sigmoid function.

  # 2dConv-based layer applied before the attention mechanism to reduce dimension.
  attention_reduce_out_channels: 32  # Output channels.
  # Kernel size. The frequency axis F entering this layer is
  # (n_fft // 2 + 1) // 2 = 100; a width-51 kernel with no padding and
  # stride 1 halves it: 100 - 51 + 1 = 50.
  attention_reduce_kernel_size: [1, 51]
  attention_reduce_padding: [0, 0]  # Padding.
  attention_reduce_stride: [1, 1]  # Stride.
  attention_reduce_group_size: 1  # Group size.

  # Input dimension for the attention mechanism after the projection
  # 2dConv-based layer: the input is [B, hid_feature, T, F] before it, and the
  # shape after it follows from the conv2d settings above.
  # NOTE(review): presumably attention_reduce_out_channels * reduced F
  # = 32 * 50 = 1600 - confirm against the model code.
  attention_embed_dim: 1600
  # Number of attention heads; per-head size is embed_dim // num_heads = 160.
  attention_num_heads: 10
  # Length of positional encoding (1500, same as in Whisper); "None" for no
  # positional encoding.
  # NOTE(review): a bare None in YAML loads as the string "None", not null -
  # confirm which of the two the model code expects.
  # The key spelling ("enccoding") is kept as-is: renaming it here would
  # require the same change wherever the key is read.
  attention_positional_enccoding_len: 1500

  # 2dConv-based layer applied after the attention mechanism to expand dimension.
  attention_expand_kernel_size: [1, 51]  # Kernel size.
  attention_expand_padding: [0, 0]  # Padding.
  attention_expand_stride: [1, 1]  # Stride.
  attention_expand_group_size: 1  # Group size.

  input_channel: 2  # Magnitude and Phase
  output_channel: 1  # Single Channel Speech Enhancement


# Room Impulse Response Generator Configuration
# Settings for the room impulse responses (RIRs) used to simulate reverberation.
rir_cfg:
  # NOTE(review): presumably the directory holding the RIR files - confirm
  # against the RIR generator.
  rir_files_path: data/sim_rir_16k
  type: RIR  # Type of simulation: "RIR" for Room Impulse Response, "PyRoom" for Pyroomacoustics simulation.
  sampling_rate: 16000  # Audio sampling rate in Hz.
  # Reverberation times to simulate.
  # NOTE(review): presumably in seconds - confirm against the RIR generator.
  reverberation_times: [0.15, 0.175, 0.2, 0.225, 0.25]
  # NOTE(review): presumably the RIR length in samples - confirm.
  rir_samples: 512
  hp_filter: true  # Whether to apply a high-pass filter.
  version: 5  # Simulation version.
