# Configuration file for two-stage MRI reconstruction training with Nirvana model
# This file contains all training parameters and can be easily modified

# Model Configuration
# Backbone config/weight locations plus the top-level input and ViT dimensions.
model:
  # Relative paths — presumably resolved against the launch directory; verify in the loader.
  config_path: "./nirvana_1_3B.json"
  backbone_path: "./model.safetensors"
  num_coils: 16
  img_size: 320
  # The three vit_* values equal architecture.kspace_encoder.vit
  # (embed_dim / num_layers / num_heads) further down in this file, and
  # stage2.decoder.token_dim is documented as needing to match vit_embed_dim.
  vit_embed_dim: 768
  vit_num_layers: 6
  vit_num_heads: 12

# Data Configuration
# Dataset location, loader worker count, and the k-space undersampling mask.
data:
  data_dir: "./brain_multicoil_train"
  num_workers: 8
  # Each list holds a single entry. Presumably the two lists are paired
  # element-wise (one center fraction per acceleration factor), as in
  # fastMRI-style mask functions — confirm against the mask implementation.
  kspace_mask:
    center_fractions: [0.08]
    accelerations: [4]

# Stage 1: K-space Encoder Training
stage1:
  # Schedule and optimizer hyper-parameters
  epochs: 100
  batch_size: 4
  learning_rate: 3.0e-4
  weight_decay: 1.0e-4
  max_grad_norm: 1.0
  warmup_ratio: 0.1
  label_smoothing: 0.1
  l2_reg: 0.0

  # Checkpointing and logging
  save_dir: "./stage1_checkpoints"
  save_interval: 10
  log_interval: 100
  log_gradients: true

  # Loss function: cross-entropy on MRI analysis tokens.
  # label_smoothing here repeats stage1.label_smoothing one level up; both
  # are 0.1. NOTE(review): confirm which of the two the trainer reads.
  loss:
    type: "cross_entropy"
    ignore_index: -100
    label_smoothing: 0.1

# Stage 2: Image Decoder Training
stage2:
  # Schedule and optimizer hyper-parameters
  epochs: 100
  batch_size: 4
  learning_rate: 1.0e-4
  weight_decay: 1.0e-4
  max_grad_norm: 1.0
  warmup_ratio: 0.1
  l2_reg: 0.0

  # Checkpointing and logging
  save_dir: "./stage2_checkpoints"
  save_interval: 10
  log_interval: 100
  log_gradients: true
  log_image_stats: true

  # Image decoder configuration
  decoder:
    # Options: "full" (160M params), "lightweight" (80M params)
    type: "full"
    hidden_dim: 512
    use_bilinear_upsample: true
    # Should match model.vit_embed_dim
    token_dim: 768

  # Loss function: SSIM loss for image reconstruction
  loss:
    type: "ssim"
    # No SSIM parameters are set here (using fastMRI default).
    # NOTE(review): losses.stage2.ssim sets window_size 11 / sigma 1.5 —
    # confirm those are the fastMRI defaults meant by this comment.

# Training Configuration
# Settings that are not specific to one stage.
training:
  mixed_precision: "bf16"
  gradient_accumulation_steps: 1
  early_stopping_patience: 20
  resume_training: false

  # Checkpoint management
  checkpoint:
    save_best: true
    save_last: true
    max_checkpoints: 5

  # Logging
  logging:
    level: "INFO"
    log_to_file: true
    log_to_console: true
    tensorboard: false

  # Validation
  # NOTE(review): validation is disabled while early_stopping_patience and
  # checkpoint.save_best are set; confirm which metric those use when no
  # validation pass runs.
  validation:
    enabled: false
    interval: 1
    metrics: ["ssim", "psnr", "mse"]

# Hardware Configuration
# Device selection and GPU resource limits.
hardware:
  device: "auto"  # Accepted values: auto, cuda, cpu
  num_gpus: 1
  # Presumably the fraction of GPU memory the process may claim — confirm in the trainer.
  memory_fraction: 0.9
  
# Optimization
# Optimizer and LR-scheduler selection with their hyper-parameters.
optimization:
  optimizer: "adamw"
  scheduler: "cosine_with_warmup"

  # AdamW parameters
  adamw:
    betas: [0.9, 0.999]
    # Keep the explicit mantissa dot: YAML 1.1 loaders such as PyYAML read a
    # bare `1e-8` as the string "1e-8" instead of a float.
    eps: 1.0e-8

  # Learning rate scheduling
  # warmup_ratio is also set per stage (stage1.warmup_ratio and
  # stage2.warmup_ratio); all three are 0.1.
  lr_schedule:
    warmup_ratio: 0.1
    min_lr_ratio: 0.01
    
# Data Augmentation (if applicable)
# Everything is switched off: the master `enabled` flag and each individual
# transform are all false.
augmentation:
  enabled: false
  random_flip: false
  random_rotation: false
  random_brightness: false
  random_contrast: false

# Model Architecture Details
architecture:
  # K-space encoder (VarNet + ViT)
  kspace_encoder:
    varnet:
      num_cascades: 12
      chans: 18
      sens_chans: 8
    # embed_dim / num_layers / num_heads repeat model.vit_embed_dim,
    # model.vit_num_layers and model.vit_num_heads; keep them in sync.
    vit:
      embed_dim: 768
      num_layers: 6
      num_heads: 12
      mlp_ratio: 4.0
      dropout: 0.1

  # Image decoder (U-Net)
  image_decoder:
    unet:
      # Should match stage2.decoder.hidden_dim
      in_channels: 512
      out_channels: 1
      # Feature widths for the full decoder; the lightweight decoder
      # alternative is kept commented out underneath.
      features: [64, 128, 256, 512, 1024]
      # features: [32, 64, 128, 256]
      bilinear: true
      dropout: 0.1

  # Nirvana backbone (frozen)
  # Presumably these must agree with the JSON at model.config_path — verify.
  backbone:
    model_type: "transformer_rnn"
    hidden_size: 768
    num_layers: 24
    num_attention_heads: 12
    intermediate_size: 3072
    vocab_size: 50257

# Loss Functions
# Detailed per-stage loss parameters.
losses:
  # Stage 1: Cross-entropy for MRI analysis tokens
  # ignore_index and label_smoothing repeat the values under stage1.loss.
  stage1:
    cross_entropy:
      ignore_index: -100
      label_smoothing: 0.1
      reduction: "mean"

  # Stage 2: SSIM for image reconstruction
  stage2:
    ssim:
      window_size: 11
      sigma: 1.5
      data_range: 1.0
      channel: 1
      size_average: true

    # Additional losses (optional); all three are set to 0.0.
    # Presumably these are weights and 0.0 disables the term — confirm.
    l1: 0.0
    l2: 0.0
    perceptual: 0.0

# Evaluation Metrics
# Written as a flow list, like training.validation.metrics.
# NOTE(review): this list has nmse where training.validation.metrics has
# mse — confirm the difference is intended.
metrics:
  image_quality: ["ssim", "psnr", "nmse"]

# Output Configuration
# Where results go and which artifacts are written.
output:
  save_dir: "./outputs"
  save_images: true
  save_metrics: true
  save_predictions: true

  # Image saving
  images:
    format: "png"
    # NOTE(review): a quality setting normally applies to lossy formats;
    # confirm it has any effect while format is "png".
    quality: 95
    save_original: true
    save_reconstructed: true
    save_difference: true

  # Metrics saving
  metrics:
    format: "json"
    save_per_epoch: true
    save_final: true

# Environment
# Random seed and backend reproducibility/performance switches.
environment:
  seed: 42
  # A seed is fixed but deterministic mode is off, so runs are presumably
  # not guaranteed to be bit-reproducible — confirm if that matters.
  deterministic: false
  # NOTE(review): benchmark and cudnn_benchmark look like overlapping
  # switches — confirm that both are actually read by the trainer.
  benchmark: true
  cudnn_benchmark: true
  
# Logging and Monitoring
# Flags selecting what is reported during training.
monitoring:
  # Progress tracking
  # log_every_n_steps has the same value (100) as stage1.log_interval and
  # stage2.log_interval.
  progress:
    show_progress_bar: true
    log_every_n_steps: 100

  # Performance monitoring
  performance:
    log_memory: true
    log_gpu_utilization: true
    log_compute_time: true

  # Model monitoring
  model:
    log_parameter_norms: true
    log_gradient_norms: true
    log_learning_rate: true