---
# SMIL Model Configuration
model_name: smil

# -----------------------
# Encoder Selection
# -----------------------
ehr_encoder: transformer  # Options: lstm, transformer
cxr_encoder: resnet50     # Options: resnet50, vit_b_16
pretrained: true          # Use pretrained weights (which encoder(s) it affects: confirm)

# -----------------------
# EHR Encoder Parameters
# -----------------------
# LSTM-specific (applies when ehr_encoder is lstm)
ehr_num_layers: 1        # Number of LSTM layers
ehr_bidirectional: true  # Bidirectional LSTM

# Transformer-specific (applies when ehr_encoder is transformer)
# Note: ehr_n_layers is a separate key from the LSTM's ehr_num_layers.
ehr_n_head: 4           # Number of attention heads
ehr_n_layers: 1         # Number of transformer layers
max_len: 500            # Maximum sequence length

# -----------------------
# Model Architecture
# -----------------------
input_dim: 24  # EHR input feature dimension (presumably must equal the EHR feature count; confirm)
hidden_dim: 256  # Hidden size (which modules share it: confirm against model code)
dropout: 0.2  # Dropout rate

# Fusion
fusion_type: smil  # Fusion strategy; matches model_name

# -----------------------
# Task Configuration
# -----------------------
task: phenotype  # Options: phenotype, mortality
num_classes: 25  # Set per task: phenotype -> 25, mortality -> 1

# Legacy parameters, superseded by cxr_encoder / ehr_num_layers.
vision_backbone: resnet50  # Legacy parameter, now handled by cxr_encoder
# NOTE(review): layers (2) disagrees with ehr_num_layers (1); confirm which
# key the model actually reads before tuning either.
layers: 2  # Legacy parameter, now handled by ehr_num_layers

# -----------------------
# CXR K-means Config (Pre-computed)
# -----------------------
cxr_mean_path: ../models/smil/cxr_mean
# cxr_mean_name is auto-generated from fold, data_type, encoder and
# n_clusters, using the format:
#   cxr_mean_fold{fold}_{data_type}_{encoder}_{n_clusters}clusters.npy
# Examples:
#   cxr_mean_fold1_matched_resnet50_10clusters.npy
#   cxr_mean_fold1_full_resnet50_10clusters.npy
cxr_img_size: 224  # Legacy parameter, unused when loading pre-computed files
n_clusters: 10     # Must match the cluster count of the pre-computed file

# -----------------------
# Meta-learning Config
# -----------------------
# Tuning ranges for these keys are listed at the end of this file.
inner_loop: 1    # Presumably the number of inner-loop updates; confirm
lr_inner: 0.01   # Inner-loop learning rate (distinct from lr)
mc_size: 30      # NOTE(review): likely a Monte Carlo sample count; confirm

# -----------------------
# Training Config
# -----------------------
mode: train
batch_size: 32
epochs: 50
patience: 10     # Presumably early-stopping patience in epochs; confirm
lr: 0.0001       # Main learning rate (distinct from lr_inner)

# -----------------------
# Loss Config
# -----------------------
alpha: 0.05       # Feature distillation weight
beta: 0.05        # EHR mean distillation weight
temperature: 3.0  # Knowledge distillation temperature

use_label_weights: false  # Enable/disable label weights
# label_weight_method options: 'balanced', 'inverse', 'sqrt_inverse',
# 'log_inverse', 'custom' (presumably only read when use_label_weights
# is true; confirm).
label_weight_method: balanced

# -----------------------
# Fine-tuning search grid
# -----------------------
# Candidate values for each hyperparameter during fine-tuning:
# inner_loop: 1, 2, 3
# mc_size: 10, 20, 30
# lr_inner: 0.001, 0.01, 0.05
# alpha: 0.05, 0.1, 0.2
# beta: 0.05, 0.1, 0.2
# temperature: 1.0, 2.0, 3.0