# DAFT Model Configuration
# Hyperparameters for a model that fuses EHR and chest X-ray (CXR) inputs
# with DAFT fusion (presumably the Dynamic Affine Feature Map Transform of
# Poelsterl et al., 2021 -- confirm against the model implementation).
model_name: daft

# -----------------------
# Encoder Selection
# -----------------------
# Backbone used to encode each modality before fusion.
ehr_encoder: transformer  # Options: 'lstm', 'transformer'
cxr_encoder: resnet50  # Options: 'resnet50', 'vit_b_16'

# -----------------------
# EHR Encoder Params (for transformer)
# -----------------------
# Used when ehr_encoder is 'transformer'.
# NOTE(review): confirm these are ignored when ehr_encoder is 'lstm'.
ehr_n_head: 4  # Presumably the number of attention heads per layer
ehr_n_layers: 1  # Presumably the number of stacked encoder layers

# -----------------------
# DAFT Fusion Params
# -----------------------
# Layer at which DAFT fusion is applied; -1 applies it to all layers.
layer_after: -1
# Activation used by the DAFT fusion step.
daft_activation: linear  # Options: 'linear', 'sigmoid', 'tanh'

# -----------------------
# Model Structure Params
# -----------------------
dim: 256  # Hidden dimension
input_dim: 49  # EHR input dimension (presumably features per time step -- confirm)
# Number of output labels; presumably the 25 phenotype labels used when
# task is 'phenotype' -- must be kept in sync if the task changes.
num_classes: 25

# -----------------------
# Training Config
# -----------------------
mode: train  # NOTE(review): confirm the other allowed values (e.g. an eval/test mode)
task: phenotype
batch_size: 16
epochs: 50
lr: 0.0001  # Learning rate
dropout: 0.2
patience: 10  # Presumably early-stopping patience in epochs -- confirm

# -----------------------
# Data Config
# -----------------------
# Presumably restricts the dataset to samples with both EHR and CXR data -- confirm.
data_pairs: paired_ehr_cxr

# -----------------------
# Load / Resume Options
# -----------------------
pretrained: true  # Whether to use pretrained weights for the CXR encoder
# Presumably a path to a saved checkpoint to restore; null means start from
# scratch. NOTE(review): confirm against the checkpoint-loading code.
load_state: null

# -----------------------
# Label Weighting
# -----------------------
use_label_weights: false  # Enable/disable per-label weights
# Weighting scheme; presumably only consulted when use_label_weights is true.
# Options: 'balanced', 'inverse', 'sqrt_inverse', 'log_inverse', 'custom'
label_weight_method: balanced