
# --- Model Parameters (for NeuralToPhonemeTransformer) ---
model_params:
  ecog_n_channels: 256      # IMPORTANT: Number of ECoG channels in your input data
  d_model: 512              # Dimensionality of the model
  n_head: 8                 # Number of attention heads
  num_encoder_layers: 6     # Number of layers in the Transformer encoder
  num_decoder_layers: 6     # Number of layers in the Transformer decoder
  dim_feedforward: 2048     # Dimension of the feedforward network in Transformer layers
  dropout: 0.4              # Dropout rate
  transformer_activation: 'gelu' # Activation function in Transformer
  vocab_size: 43            # IMPORTANT: Total phoneme tokens (incl. SOS, EOS, PAD); PHONEME_MAP uses indices 0-42
  pad_idx: 0                # Padding token index ('PAD' in PHONEME_MAP)
  eos_idx: 42               # End-of-sequence token index ('EOS' in PHONEME_MAP)
  sos_idx: 41               # Start-of-sequence token index ('SOS' in PHONEME_MAP)

  # encoder_implementation: 'loop'
  # aux_head_input_layer_idx: 1
  # mfcc_aux_head_type: 'mlp' # "linear" or "mlp"

  # --- SpecAugment-style Masking ---
  time_masking_prob: 0.30
  time_mask_max_len: 25              # Global max mask length in time steps
  time_mask_max_proportion: 0.2      # Mask at most 20% of a sequence's length

  channel_masking_prob: 0.25
  channel_mask_max_electrodes: 15    # Max number of physical electrodes to mask

  # Day transformation (optional - set use_day_transform: false if not using)
  use_day_transform: true
  day_transform_type: "LightFiLM" # "FC", "FiLM", "LightFiLM"
  num_days: 24              # Example: total number of unique days/sessions if use_day_transform is true

  # --- Downsampling Settings ---
  downsampling_strategy: 'conv'                  # 'conv', 'unfold' or 'none'
  downsampling_factor: 4                         # Unified key for the downsampling factor

  # feature extractor
  # Options: "FC", "binned_conv", "binned_attention_conv", "binned_attention_conv_downsample",
  #          "Deep_conv" or "Interpretable_conv"
  feature_extractor_type: "binned_attention_conv_downsample"
  feature_extractor_kernel_size: 5       # Kernel size for initial conv. Should be >= downsampling_factor.

  # --- Main Phoneme Decoder Loss Weight (NEW) ---
  phoneme_main_loss_weight: 1.0 # Weight for the main phoneme decoder's loss

  # --- Auxiliary Task Parameters ---
  train_mfcc_aux: false
  num_mfcc_features: 14
  mfcc_loss_weight: 0.005

  train_envelope_aux: false
  envelope_loss_weight: 0.05

  train_phoneme_len_aux: false
  phoneme_len_loss_weight: 0.05
  use_predicted_len_for_sg: false

  train_word_count_aux: false
  word_count_loss_weight: 0.05

  # Control inclusion of auxiliary losses in the main training loss
  train_aux_in_joint_tf: true
  train_aux_in_joint_sg: false

  # --- Freezing Parameters for Phoneme Path ---
  freeze_encoder_in_joint: false
  freeze_decoder_in_joint: false

  # Validation settings for the main phoneme task
  max_gen_len_val: 70      # Max length for phoneme sequential generation during validation
  sg_gen_buffer: 5         # Buffer for phoneme SG generation length if guidance used

  # BART Decoder Head (ENABLED)
  train_bart_text_decoder: true
  bart_type: "bart-base"
  bart_text_loss_weight: 1.0 # Weight for the BART text-decoder loss (cf. phoneme_main_loss_weight)
  bart_freezing_strategy: "freeze_first_3_layers"
  bart_val_num_beams: 4
  bart_max_target_len: 128
  bart_max_gen_len_val: 150

# --- Data Parameters (for ECoGDataModule) ---
data_params_ecog:
  # NOTE(review): left unset (explicit null; the previous bare "key:" also meant null).
  # Presumably this must point at the ECoG dataset before training.
  data_path_ecog: null
  # max_input_length: 500
  neural_data_type: "all"
  neural_data_location: "6v"
  daysOnly_train: null
  daysOnly_val: null
  smoothing: true
  gaussianSmoothWidth: 2
  # Special-token indices: keep identical to model_params and PHONEME_MAP.
  pad_idx: 0
  eos_idx: 42
  sos_idx: 41
  curriculum_learning_enabled: false

# --- Common Training Parameters ---
train_params_common:
  batch_size: 4
  num_workers: 4
  precision: 32
  save_top_k: 2
  early_stopping_patience: 100
  use_scheduler: true
  scheduler_mode: 'min'
  scheduler_factor: 0.5
  scheduler_patience: 10
  # IMPORTANT: Metric the scheduler monitors. Presumably the epoch-level phoneme error rate (PER)
  # from sequential-generation (SG) validation; lower is better, matching scheduler_mode 'min'.
  scheduler_monitor: 'val_per_sg_epoch'
  accumulate_grad_batches: 5

# --- Training Parameters for 'joint_teacher_forcing' stage ---
train_params_joint_tf:
  # Floats are written with a decimal point: YAML 1.1 loaders (e.g. PyYAML) resolve
  # '1e-4' as the string "1e-4", whereas '1.0e-4' is a float under both YAML 1.1 and 1.2.
  learning_rate: 1.0e-4     # Starting learning rate for this stage
  max_epochs: 200           # Number of epochs for this stage
  sg_val_every_n_epochs: 1  # How often to run full SG validation during TF stage validation (optional, can be high if not primary focus)
  weight_decay: 1.0e-3

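# --- STAGE 1 CHECKPOINT FOR WEIGHT INITIALIZATION ---
# A checkpoint path set here would be used to load the weights of the pre-trained Stage 1 model.
# NOTE(review): no checkpoint-path key appears under this header; add one if Stage 1 initialization is intended.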

# Token index -> phoneme symbol (special tokens first, then phonemes 1-40).
# Keep every symbol quoted: some YAML 1.1 loaders read plain N or Y as booleans.
PHONEME_MAP:
  0: 'PAD'
  41: 'SOS'
  42: 'EOS'
  1: 'AA'
  2: 'AE'
  3: 'AH'
  4: 'AO'
  5: 'AW'
  6: 'AY'
  7: 'B'
  8: 'CH'
  9: 'D'
  10: 'DH'
  11: 'EH'
  12: 'ER'
  13: 'EY'
  14: 'F'
  15: 'G'
  16: 'HH'
  17: 'IH'
  18: 'IY'
  19: 'JH'
  20: 'K'
  21: 'L'
  22: 'M'
  23: 'N'
  24: 'NG'
  25: 'OW'
  26: 'OY'
  27: 'P'
  28: 'R'
  29: 'S'
  30: 'SH'
  31: 'T'
  32: 'TH'
  33: 'UH'
  34: 'UW'
  35: 'V'
  36: 'W'
  37: 'Y'
  38: 'Z'
  39: 'ZH'
  40: 'SIL'