# Configuration for AE-enhanced mangrove segmentation training (separability + orthogonality loss)

data:
  # Path to the processed dataset with AE features
  root_dir: 'datasets/GEE/sentinel-2_w0.45_0.1_0.45_split'
  
model:
  name: 'AEUnetPlusPlus'
  args:
    encoder_name: 'resnet34'  # Encoder for the UNet++ backbone
    in_img: 6                 # Satellite image channels (B, G, R, NIR, SWIR1, SWIR2)
    ae_dim: 64                # Number of AE feature channels
    D: 3                      # Number of projected AE channels
    classes: 1                # Binary segmentation (matches SegmentationTrainer)
    # Weight of the BCE term in the Dice+BCE loss (matches SegmentationTrainer)
    bce_weight: 2.0

train:
  seed: 42
  uid: 'ae_sep_ortho_v1'  # Unique ID for this training run

  # Training Hyperparameters
  epoch: 50
  # Small batch size to limit memory pressure from the large AE feature inputs
  batch_size: 8
  learning_rate: 0.001

  # Loss Function Configuration - Separability + Orthogonality Loss
  loss:
    name: 'AELoss'      # Custom multi-component loss
    lambda_sep: 0.5     # Weight for separability loss
    lambda_ortho: 0.01  # Weight for orthogonality loss (keep small)
    lambda_tv: 0.0      # Total variation loss weight (0.0 disables it)
    lambda_mag: 0.0     # Magnitude penalty weight (usually 0; use weight_decay instead)
    # Note: Dice+BCE loss weights are controlled by bce_weight in model.args

  # Optimizer
  optimizer:
    name: 'AdamW'
    args:
      weight_decay: 0.01
      betas: [0.9, 0.999]

  # Learning Rate Scheduler
  scheduler:
    name: 'CosineAnnealingWarmRestarts'
    args:
      # Scheduler steps until the first restart. Whether one step is an epoch or
      # a batch depends on where the trainer calls scheduler.step() - confirm.
      T_0: 15
      eta_min: 1.0e-6  # Minimum learning rate at the bottom of each cycle
      # Each cycle is T_mult times longer than the previous one (15, 30, 60, ...).
      # NOTE(review): if stepped once per epoch, restarts occur at epochs 15 and
      # 45, so with epoch: 50 the final 5 epochs run just after a restart at
      # near-peak LR. Consider epoch: 45 (or a different T_0) so training ends
      # at the bottom of a cycle - verify the step granularity first.
      T_mult: 2

  log_dir: 'logs'
  # Hardware and DDP
  n_workers: 2    # Kept low for the large AE dataset to limit memory pressure
  no_ddp: false   # Set to true to run on a single GPU without DDP
  no_save: false  # Set to true to disable checkpoint saving

  # Early Stopping
  # Patience before stopping; presumably counted in epochs without validation
  # improvement - TODO confirm against the trainer
  patience: 20

  # Image logging interval (in steps) for training
  log_image_interval: 200
