# Training configuration for Polygenic Risk Score prediction pipeline

# Model configuration for PRS prediction
model:
  # Architecture to build. Accepted values: mlp, cnn, transformer,
  # attention, bayesian/bnn. Only the group matching this value is relevant.
  model_type: 'mlp'

  # --- MLP (model_type: mlp) — recommended for PRS ---
  # Layer widths: a deeper network for complex genetic patterns.
  hidden_dims: [1000, 250, 50]
  dropout_rate: 0.2
  activation: 'relu'  # one of: relu, elu, leaky_relu
  batch_norm: true

  # --- CNN (model_type: cnn) ---
  channels: [32, 64, 128]
  kernel_sizes: [7, 5, 3]
  pool_sizes: [2, 2, 2]
  fc_dims: [256, 128]

  # --- Transformer (model_type: transformer) ---
  d_model: 256
  n_heads: 8
  n_layers: 4
  d_ff: 1024

  # --- Attention (model_type: attention) ---
  hidden_dim: 256
  attention_dim: 128
  n_attention_heads: 4

  # --- Bayesian NN (model_type: bayesian or bnn) ---
  # A smaller stack such as [512, 256, 128] is suggested because of the
  # uncertainty modeling. To use it, edit `hidden_dims` in the MLP group.
  # Do not uncomment the line below as-is: a second `hidden_dims` key is a
  # duplicate key, which most YAML loaders resolve silently (last one wins),
  # so it would override the MLP value for every model type.
  # hidden_dims: [512, 256, 128]
  prior_var: 1.0   # prior variance of the weight distributions
  n_samples: 10    # forward passes used for uncertainty estimation
  kl_weight: 0.01  # weight on the KL-divergence term of the ELBO loss

# Optimizer configuration
optimizer:
  type: 'adamw'       # one of: adam, adamw, sgd
  lr: 0.001           # learning rate (the scheduler section may reduce it)
  weight_decay: 0.01
  momentum: 0.9       # only meaningful when type is sgd

# Learning rate scheduler
scheduler:
  type: 'reduce_on_plateau'  # one of: step, cosine, reduce_on_plateau

  # Read when type is reduce_on_plateau (ReduceLROnPlateau).
  patience: 10
  factor: 0.1

  # Read when type is step (StepLR).
  # NOTE(review): step_size (30) exceeds max_epochs (20); if it is counted
  # in epochs, StepLR would never decay the rate — confirm.
  step_size: 30
  gamma: 0.1

# Training parameters
max_epochs: 20
# NOTE(review): 500 is not among hyperparameter_search.batch_size choices.
batch_size: 500
num_workers: 4
gradient_clip: 1.0

# Loss function: 'auto' is the recommended setting; explicit choices are
# mse, mae, huber and crossentropy.
loss_function: 'auto'

# Early stopping.
# NOTE(review): patience (20) equals max_epochs (20), so early stopping can
# most likely never trigger — confirm whether max_epochs should be raised
# or patience lowered.
early_stopping:
  patience: 20
  min_delta: 0.0001

# Data augmentation
augment_train: false  # augmentation is currently switched off
# Presumably only consulted when augment_train is true — verify in consumer.
augmentation_params:
  snp_dropout: 0.1
  noise_std: 0.05
  shuffle_regions: 10

# Logging with Weights & Biases
use_wandb: true  # log runs to Weights & Biases
wandb_project: 'prs-prediction-A'
wandb_entity: null  # fill in your W&B username or team
log_interval: 10
save_checkpoint_interval: 5

# Hardware and reproducibility
use_cuda: false
cuda_deterministic: false
seed: 42

# Data preprocessing
preprocessing:
  impute_missing: false
  standardize: false
  remove_low_maf: false
  # Minor-allele-frequency cutoff; presumably only used when remove_low_maf
  # is true.
  maf_threshold: 0.01
  
# Cross-validation
cv:
  n_folds: 5
  # NOTE(review): stratified splitting usually needs discrete labels; if the
  # PRS target is continuous, confirm how the consumer stratifies it.
  stratified: true

# Evaluation metrics
metrics: [mse, mae, r2, correlation, explained_variance]

# Hyperparameter search space (for Optuna)
hyperparameter_search:
  # NOTE(review): bayesian/bnn is a valid model.model_type but is not
  # searched here — confirm that exclusion is intentional.
  model_type:
    type: categorical
    choices: ["mlp", "cnn", "transformer", "attention"]

  # Exponent-notation bounds are written with a decimal point (1.0e-5, not
  # 1e-5): YAML 1.1 loaders such as PyYAML resolve exponent notation as a
  # float only when the mantissa contains a '.', so a bare 1e-5 would load
  # as the string "1e-5" instead of a number.
  # NOTE(review): the optimizer section names this setting `lr`; confirm
  # the search code maps `learning_rate` onto it.
  learning_rate:
    type: float
    low: 1.0e-5
    high: 1.0e-2
    log: true

  dropout_rate:
    type: float
    low: 0.1
    high: 0.5

  # NOTE(review): the top-level default batch_size (500) is outside these
  # choices.
  batch_size:
    type: categorical
    choices: [16, 32, 64, 128]

  weight_decay:
    type: float
    low: 1.0e-6
    high: 1.0e-2
    log: true