# Multimodal Benchmark for Clinical Learners

A comprehensive benchmark for evaluating multimodal clinical machine learning models on MIMIC-IV and MIMIC-CXR datasets. This repository provides implementations of state-of-the-art models for handling missing modalities, fusion methods, and longitudinal time series in clinical settings.

## 📋 Table of Contents

- [Project Overview](#-project-overview)
- [Project Structure](#-project-structure)
- [Quick Start](#-quick-start)
- [Model Configurations](#-model-configurations)
- [Dataset Information](#-dataset-information)
- [Adding New Models](#-adding-new-models)
- [Experiment Results](#-experiment-results)
- [Installation](#-installation)
- [License](#-license)
- [Citation](#-citation)

## 🎯 Project Overview

This benchmark focuses on multimodal clinical learning with Electronic Health Records (EHR) and Chest X-rays (CXR). The framework supports:

- **Missing Modality Handling**:  
  - [SMIL: Multimodal Learning with Severely Missing Modality](https://arxiv.org/abs/2103.05677)  
  - [M3Care: Learning with Missing Modalities in Multimodal Healthcare Data](https://arxiv.org/abs/2210.17292)  
  - [ShaSpec: Multi-modal Learning with Missing Modality via Shared-Specific Feature Modelling](https://arxiv.org/abs/2307.14126)  
  - [DrFuse: Learning Disentangled Representation for Clinical Multi-Modal Fusion with Missing Modality and Modal Inconsistency](https://arxiv.org/abs/2403.06197)  

- **Fusion Methods**:  
  - [MMTM: Multimodal Transfer Module for CNN Fusion](https://arxiv.org/abs/1911.08670)  
  - [DAFT: Combining 3D Image and Tabular Data via the Dynamic Affine Feature Map Transform](https://arxiv.org/abs/2107.05990)  
  - [MedFuse: Multi-modal fusion with clinical time-series data and chest X-ray images](https://arxiv.org/abs/2207.07027)  
  - [HEALNet: Multimodal Fusion for Heterogeneous Biomedical Data](https://arxiv.org/abs/2311.09115)  
  - [FlexMoE: Modeling Arbitrary Modality Combination via the Flexible Mixture-of-Experts](https://arxiv.org/abs/2410.08245)  

- **Longitudinal / Temporal / Irregular Time Series Analysis**:  
  - [UTDE: Improving Medical Predictions by Irregular Multimodal Electronic Health Records Modeling](https://arxiv.org/abs/2210.12156)  
  - [UMSE: Learning Missing Modal Electronic Health Records with Unified Multi-modal Data Embedding and Modality-Aware Attention](https://arxiv.org/abs/2305.02504)



**Supported Tasks:**
- Mortality prediction
- Phenotype classification (6 phenotypes)

## 📁 Project Structure

```
benchmark-multimodal-clinical-learners/
├── models/                           # Model implementations
│   ├── base/                        
│   ├── drfuse/                      
│   ├── umse/                        
│   ├── utde/                        
│   ├── healnet/                     
│   ├── smil/                        
│   ├── daft/                        
│   ├── mmtm/                        
│   ├── medfuse/                     
│   ├── m3care/                      
│   ├── flexmoe/                     
│   ├── shaspec/                     
│   ├── registry.py                  # Decorator-based model registry
│   └── __init__.py                  # Model imports and registry
│
├── datasets/                        # Dataset handling
│   ├── dataset.py                   # Main dataset loader for MIMIC-IV & MIMIC-CXR
│   └── __init__.py
│
├── configs/                         # Model configurations
│   ├── drfuse.yaml                  # DrFuse model config
│   ├── umse.yaml                    # UMSE model config
│   ├── healnet.yaml                 # HealNet model config
│   ├── smil.yaml                    # SMIL model config
│   ├── daft.yaml                    # DAFT model config
│   ├── mmtm.yaml                    # MMTM model config
│   ├── medfuse.yaml                 # MedFuse model config
│   ├── m3care.yaml                  # M3Care model config
│   ├── flexmoe.yaml                 # FlexMoE model config
│   ├── shaspec.yaml                 # ShaSpec model config
│   └── utde.yaml                    # UTDE model config
│
├── utils/                           # Utility functions
│
├── scripts/                         # Training and evaluation scripts
├── experiments/                     # Experiment results and logs
│   └── for_{model_name}/{task}/lightning_logs/
│
├── main.py                         # Main training script
├── arguments.py                    # Command line argument parser
├── extract_mimiciv_benchmark.ipynb  # Data preprocessing notebook for MIMIC-IV benchmark dataset
├── icd_9_10_definitions.yaml       # ICD-9/10 code definitions and mappings for clinical concepts
├── requirements.txt                # Python package dependencies
└── README.md                       # Project documentation
```

## 🚀 Quick Start

Each model has its own dedicated training script in the `scripts/` directory. Let's use DrFuse as an example:

### Understanding Training Scripts

```bash
cd scripts

# Basic training with default parameters
sh train_drfuse.sh
```

### Script Structure Analysis

The `train_drfuse.sh` script contains the following configuration:

```bash
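#!/bin/bash
# Flag reference (Bash does not allow a comment after a line-continuation
# backslash, so the descriptions are listed here instead of inline):
#   --model       Model name
#   --mode        Training mode
#   --task        Task: phenotype or mortality
#   --matched     Use matched EHR-CXR pairs
#   --fold        Experiment repetition (1-5)
#   --patience    Early stopping patience
#   --epochs      Maximum training epochs
#   --batch_size  Batch size
#   --gpu         GPU device ID

python ../main.py \
    --model drfuse \
    --mode train \
    --task phenotype \
    --matched \
    --fold 1 \
    --patience 10 \
    --epochs 50 \
    --batch_size 16 \
    --gpu 0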
```

### Customizing Training Parameters


Command-line flags override the corresponding parameters in the model's YAML configuration, so you can adjust them directly in the script:

```bash
# Override the learning rate set in configs/drfuse.yaml (lr: 0.0001)
python ../main.py \
    --model drfuse \
    --mode train \
    --task mortality \
    --matched \
    --fold 1 \
    --lr 0.001 \
    --gpu 0
```
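How `main.py` merges the two sources is not shown in this README. A common pattern, sketched here with `argparse` under the assumption that flags you do not pass stay `None`, is to load the YAML file and then overwrite only the keys set on the command line. `build_parser`, `merge_config`, and the flag subset are illustrative, not the repository's code:

```python
import argparse
from typing import Any, Dict, List, Optional


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical subset of the flags used in the training scripts."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument("--task", choices=["phenotype", "mortality"], default="phenotype")
    # default=None marks "not passed on the command line", so the YAML value is kept
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--patience", type=int, default=None)
    return parser


def merge_config(yaml_cfg: Dict[str, Any], argv: Optional[List[str]] = None) -> Dict[str, Any]:
    """Return the YAML config with every non-None CLI value applied on top."""
    args = vars(build_parser().parse_args(argv))
    merged = dict(yaml_cfg)
    for key, value in args.items():
        if value is not None:  # flags left at None were not passed, so the YAML value wins
            merged[key] = value
    return merged


# In practice yaml_cfg would come from yaml.safe_load(open("configs/drfuse.yaml")).
drfuse_cfg = {"model_name": "drfuse", "lr": 0.0001, "epochs": 50, "patience": 10}
cfg = merge_config(drfuse_cfg, ["--model", "drfuse", "--task", "mortality", "--lr", "0.001"])
print(cfg["lr"], cfg["epochs"])  # 0.001 50
```

Using `None` as the sentinel lets the YAML file stay the single source of defaults, while any flag written in a script takes precedence.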


### Running Different Models

Each model follows the same pattern but with different script names:

```bash
# Train other models
sh train_umse.sh      # UMSE model
sh train_healnet.sh   # HealNet model
sh train_smil.sh      # SMIL model
sh train_daft.sh      # DAFT model
sh train_mmtm.sh      # MMTM model
sh train_medfuse.sh   # MedFuse model
sh train_m3care.sh    # M3Care model
sh train_flexmoe.sh   # FlexMoE model
sh train_shaspec.sh   # ShaSpec model
sh train_utde.sh      # UTDE model
```

### Batch Training for Repeated Experiments

To run all 5 repeated experiments automatically, uncomment the loop sections in the script:

```bash
# Edit the script to enable 5 repeated experiments
for fold in 1 2 3 4 5
do
    python ../main.py \
        --model drfuse \
        --mode train \
        --task phenotype \
        --matched \
        --fold $fold \
        --patience 10 \
        --epochs 50 \
        --batch_size 16 \
        --gpu 0
done
```

All training results (checkpoints, logged metrics, and the hyperparameters of each run) are saved automatically in the `experiments/` directory.

## ⚙️ Model Configurations

### DrFuse Configuration

```yaml
# Key Parameters
model_name: drfuse
hidden_size: 256                    # Hidden dimension size
ehr_n_head: 4                      # Number of attention heads
ehr_n_layers_distinct: 1           # Distinct layers for EHR
ehr_n_layers_shared: 1             # Shared layers
fusion_method: concate             # Fusion method
disentangle_loss: mse              # Disentanglement loss type

# Loss weights
lambda_disentangle_shared: 1.0     # Shared disentanglement weight
lambda_disentangle_ehr: 1.0        # EHR disentanglement weight
lambda_disentangle_cxr: 1.0        # CXR disentanglement weight
lambda_pred_ehr: 1.0               # EHR prediction weight
lambda_pred_cxr: 1.0               # CXR prediction weight

# Training parameters
lr: 0.0001                         # Learning rate
epochs: 50                         # Number of epochs
patience: 10                       # Early stopping patience
```
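`configs/drfuse.yaml` exposes one weight per auxiliary loss term. Assuming the training objective is the main (fused) prediction loss plus each auxiliary loss scaled by its `lambda_*` weight (the DrFuse paper and `models/drfuse/` define the actual terms), the combination can be sketched as follows; the loss-term names are hypothetical:

```python
from typing import Dict

# Weights as listed in configs/drfuse.yaml.
DRFUSE_WEIGHTS = {
    "lambda_disentangle_shared": 1.0,
    "lambda_disentangle_ehr": 1.0,
    "lambda_disentangle_cxr": 1.0,
    "lambda_pred_ehr": 1.0,
    "lambda_pred_cxr": 1.0,
}


def total_loss(pred_loss: float, aux_losses: Dict[str, float], weights: Dict[str, float]) -> float:
    """Main prediction loss plus each auxiliary term scaled by its lambda weight.

    aux_losses keys are the weight names without the "lambda_" prefix,
    e.g. {"pred_ehr": ..., "disentangle_shared": ...} (hypothetical naming).
    """
    total = pred_loss
    for name, value in aux_losses.items():
        total += weights["lambda_" + name] * value
    return total
```

Setting a weight to `0.0` in the YAML would, under this reading, switch the corresponding term off without touching the model code.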

### UMSE Configuration

```yaml
# Architecture
model_name: umse
d_model: 128                       # Model dimension
variables_num: 322                 # Number of EHR variables
num_layers: 4                      # Number of transformer layers
num_heads: 8                       # Number of attention heads
n_modality: 2                      # Number of modalities
bottlenecks_n: 4                   # Number of bottlenecks
dropout: 0.3                       # Dropout rate
max_ehr_len: 2000                  # Maximum EHR sequence length

# Training
batch_size: 32                     # Batch size
lr: 0.0004                         # Learning rate
beta_1: 0.99                       # Adam beta1 parameter
```
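UMSE-style models embed individual EHR observations rather than fixed time bins, so each sample is a variable-length sequence capped at `max_ehr_len`. How `models/umse/` builds that sequence is not shown here; the sketch below assumes time-sorted `(time, variable, value)` triples, keeps the most recent `max_ehr_len` events, and returns the kind of padding mask transformer attention layers expect. `Event` and `pad_or_truncate` are hypothetical names:

```python
from typing import List, Sequence, Tuple

# Hypothetical observation format: (time_in_hours, variable_id, value)
Event = Tuple[float, int, float]


def pad_or_truncate(events: Sequence[Event], max_ehr_len: int = 2000,
                    pad: Event = (0.0, 0, 0.0)) -> Tuple[List[Event], List[bool]]:
    """Fix a time-sorted event sequence to exactly max_ehr_len entries.

    Keeps the most recent events when the sequence is too long and pads short
    ones; the mask is True for real events and False for padding.
    """
    if max_ehr_len <= 0:
        raise ValueError("max_ehr_len must be positive")
    kept = list(events)[-max_ehr_len:]  # drop the oldest events beyond the cap
    n_pad = max_ehr_len - len(kept)
    mask = [True] * len(kept) + [False] * n_pad
    return kept + [pad] * n_pad, mask


seq, mask = pad_or_truncate([(0.5, 3, 1.2), (1.0, 7, 0.4)], max_ehr_len=4)
print(mask)  # [True, True, False, False]
```

Keeping the latest events is one choice among several; truncating from the start instead would favour observations near admission.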

### UTDE Configuration

```yaml
# Architecture
model_name: utde
num_heads: 8                       # Number of attention heads
tt_max: 48                         # Maximum time steps
orig_reg_d_ts: 498                 # Original EHR dimension
kernel_size: 1                     # Convolution kernel size
embed_dim: 128                     # Embedding dimension
embed_time: 64                     # Time embedding dimension
cross_layers: 3                    # Number of cross-attention layers
dropout: 0.3                       # Dropout rate

# Training
batch_size: 32                     # Batch size
lr: 0.0004                         # Learning rate
beta_1: 0.99                       # Adam beta1 parameter
epochs: 50                         # Number of epochs
patience: 10                       # Early stopping patience
```

### HealNet Configuration

```yaml
# Fusion Parameters
model_name: healnet
n_modalities: 2                    # Number of modalities
depth: 3                           # Number of fusion layers
latent_channels: 128               # Number of latent tokens
latent_dim: 256                    # Dimension of latent tokens
cross_heads: 8                     # Cross-attention heads
latent_heads: 8                    # Self-attention heads
cross_dim_head: 64                 # Cross-attention head dimension
latent_dim_head: 64                # Self-attention head dimension

# Advanced Features
fourier_encode_data: true          # Use Fourier encoding
num_freq_bands: 2                  # Frequency bands
max_freq: 10.0                     # Maximum frequency
snn: true                          # Self-normalizing networks
weight_tie_layers: false           # Weight sharing across layers

# Dropout
attn_dropout: 0.3                  # Attention dropout
ff_dropout: 0.3                    # Feed-forward dropout
```
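HEALNet builds on Perceiver-style latent cross-attention. In the Perceiver convention, `fourier_encode_data` appends sine and cosine features of each input position at `num_freq_bands` frequencies spaced linearly from 1 to `max_freq / 2`, plus the raw position. The plain-Python sketch below follows that convention; it is an illustration, and `models/healnet/` may differ in detail:

```python
import math
from typing import List


def fourier_encode(x: float, max_freq: float = 10.0, num_freq_bands: int = 2) -> List[float]:
    """Encode a position x in [-1, 1] as [sin(pi*f*x) ..., cos(pi*f*x) ..., x]."""
    if num_freq_bands == 1:
        scales = [1.0]
    else:
        # Frequencies spaced linearly from 1 to max_freq / 2 (Perceiver convention)
        step = (max_freq / 2 - 1.0) / (num_freq_bands - 1)
        scales = [1.0 + i * step for i in range(num_freq_bands)]
    angles = [x * f * math.pi for f in scales]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles] + [x]


def encode_positions(n: int, **kwargs) -> List[List[float]]:
    """Fourier features for n evenly spaced positions on [-1, 1]."""
    if n == 1:
        return [fourier_encode(-1.0, **kwargs)]
    return [fourier_encode(-1.0 + 2.0 * i / (n - 1), **kwargs) for i in range(n)]


# With the config values above, each position gains 2 * 2 + 1 = 5 features.
print(len(fourier_encode(0.5, max_freq=10.0, num_freq_bands=2)))  # 5
```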

### SMIL Configuration

```yaml
# Meta-learning
model_name: smil
inner_loop: 1                      # Inner loop iterations
lr_inner: 0.01                     # Inner learning rate
mc_size: 30                        # Monte Carlo sample size

# CXR Processing
cxr_mean_path: ../models/smil/cxr_mean
n_clusters: 10                     # Number of clusters
cxr_img_size: 224                  # Image size

# Knowledge Distillation
alpha: 0.05                        # Feature distillation weight
beta: 0.05                         # EHR mean distillation weight
temperature: 3.0                   # Distillation temperature

# Training
lr: 0.001                          # Learning rate
batch_size: 32                     # Batch size
```
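The `temperature` key controls how strongly logits are softened before distillation. As a reference point, the standard temperature-scaled distillation loss (Hinton et al., 2015) is sketched below in plain Python. That SMIL's distillation terms use exactly this form is an assumption (note that `alpha` weights a feature-level term), so check `models/smil/` for the actual losses:

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled softmax; higher temperatures give flatter distributions."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits: Sequence[float], teacher_logits: Sequence[float],
                      temperature: float = 3.0) -> float:
    """KL(teacher || student) on softened distributions, scaled by T^2."""
    p_t = softmax(teacher_logits, temperature)
    p_s = softmax(student_logits, temperature)
    kl = sum(pt * (math.log(pt) - math.log(ps)) for pt, ps in zip(p_t, p_s) if pt > 0)
    # The T^2 factor keeps gradient magnitudes comparable across temperatures
    return kl * temperature ** 2


print(distillation_loss([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))  # 0.0
```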

### DAFT Configuration

```yaml
# Architecture
model_name: daft
vision_backbone: resnet34          # Vision backbone
dim: 256                           # Hidden dimension
layers: 2                          # Number of layers
layer_after: 3                     # Layer position for fusion
daft_activation: linear            # Activation function

# Alignment
align: 0.0                         # Alignment loss weight
pretrained: false                  # Use pretrained weights

# Training
lr: 0.0001                         # Learning rate
batch_size: 16                     # Batch size
epochs: 50                         # Number of epochs
patience: 10                       # Early stopping patience
```
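DAFT fuses modalities by predicting a per-channel scale and shift for a CNN feature map from the globally pooled image features concatenated with the tabular (here, EHR) features. The plain-list sketch below simplifies DAFT's bottleneck conditioning network to a single linear layer; reading `daft_activation: linear` as "no extra nonlinearity on the predicted scale" and `layer_after` as the backbone stage being modulated are assumptions:

```python
from typing import List, Sequence

Matrix = List[List[float]]


def linear(w: Matrix, b: Sequence[float], x: Sequence[float]) -> List[float]:
    """y = W x + b for plain Python lists."""
    return [sum(wi * xi for wi, xi in zip(row, x)) + bi for row, bi in zip(w, b)]


def daft_modulate(feature_map: Matrix, tabular: Sequence[float],
                  w_scale: Matrix, b_scale: Sequence[float],
                  w_shift: Matrix, b_shift: Sequence[float]) -> Matrix:
    """Apply a per-channel affine transform conditioned on image and tabular data.

    feature_map is C channels x N spatial positions; each weight matrix maps the
    (C + len(tabular))-dim context vector to C outputs.
    """
    pooled = [sum(ch) / len(ch) for ch in feature_map]  # global average pool per channel
    context = pooled + list(tabular)                      # image summary + EHR features
    scale = linear(w_scale, b_scale, context)             # no activation ("linear")
    shift = linear(w_shift, b_shift, context)
    return [[s * v + t for v in ch] for ch, s, t in zip(feature_map, scale, shift)]
```

With zero weights, a scale bias of 1, and a shift bias of 0, the transform reduces to the identity, which is a convenient sanity check when initializing such a layer.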

### MMTM Configuration

```yaml
# Architecture
model_name: mmtm
vision_backbone: resnet34          # Vision backbone
vision_num_classes: 6              # Vision output classes
fusion_type: mmtm                  # Fusion type
layer_after: 4                     # Fusion layer position
dim: 256                           # Hidden dimension
layers: 2                          # Number of layers
input_dim: 498                     # EHR dimension
num_classes: 6                     # Output classes

# Training
lr: 0.0001                         # Learning rate
batch_size: 16                     # Batch size
epochs: 50                         # Number of epochs
patience: 10                       # Early stopping patience
beta_1: 0.9                        # Adam beta1 parameter
dropout: 0.3                       # Dropout rate

# Options
pretrained: false                  # Use pretrained weights
align: 0.0                         # Alignment loss weight
```

### MedFuse Configuration

```yaml
# Architecture
model_name: medfuse
hidden_size: 256                   # Hidden dimension
num_heads: 8                       # Number of attention heads
num_layers: 4                      # Number of layers
dropout: 0.3                       # Dropout rate
input_dim: 498                     # EHR dimension
num_classes: 6                     # Output classes

# Vision
vision_backbone: resnet50          # Vision backbone
pretrained: true                   # Use pretrained weights

# Training
lr: 0.0001                         # Learning rate
batch_size: 16                     # Batch size
epochs: 50                         # Number of epochs
patience: 10                       # Early stopping patience
```

### M3Care Configuration

```yaml
# Architecture
model_name: m3care
input_dim: 24                      # Input dimension
hidden_dim: 256                    # Hidden dimension
dropout: 0.2                       # Dropout rate
stab_reg_lambda: 0.2               # Stability regularization weight
num_classes: 25                    # Output classes

# EHR Processing
ehr_n_head: 2                      # EHR attention heads
ehr_dropout: 0.2                   # EHR dropout rate

# Training
lr: 0.0001                         # Learning rate
beta_1: 0.9                        # Adam beta1 parameter
```

### FlexMoE Configuration

```yaml
# Architecture
model_name: flexmoe
hidden_dim: 512                    # Hidden dimension
num_patches: 16                    # Number of patches
num_layers: 2                      # Number of layers
num_layers_pred: 4                 # Prediction layers
num_experts: 16                    # Number of experts
num_routers: 2                     # Number of routers
top_k: 4                           # Top-k experts
num_heads: 8                       # Number of attention heads
dropout: 0.3                       # Dropout rate

# Vision
vision_backbone: resnet50          # Vision backbone
pretrained: true                   # Use pretrained weights

# Task
input_dim: 498                     # Input dimension
num_classes: 6                     # Output classes

# Training
lr: 0.0001                         # Learning rate
weight_decay: 0.01                 # Weight decay
gate_loss_weight: 0.01             # Gate loss weight
beta_1: 0.99                       # Adam beta1 parameter
epochs: 50                         # Number of epochs
pct_start: 0.3                     # Warmup percentage
div_factor: 25                     # Learning rate division factor
final_div_factor: 1e4              # Final division factor
```
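The `pct_start`, `div_factor`, and `final_div_factor` keys share their names with the arguments of PyTorch's `torch.optim.lr_scheduler.OneCycleLR`. Assuming FlexMoE passes them to that scheduler with `lr` as `max_lr` (check `models/flexmoe/` to confirm), the learning rate follows PyTorch's default two-phase cosine schedule, reproduced below in plain Python. With the values above it starts at 1e-4 / 25 = 4e-6, peaks at 1e-4 after 30% of the steps, and decays to 4e-6 / 1e4 = 4e-10:

```python
import math


def _cos_anneal(start: float, end: float, pct: float) -> float:
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1.0)


def one_cycle_lr(step: int, total_steps: int, max_lr: float = 1e-4, pct_start: float = 0.3,
                 div_factor: float = 25.0, final_div_factor: float = 1e4) -> float:
    """Learning rate at `step` under the default two-phase OneCycleLR schedule."""
    initial_lr = max_lr / div_factor        # LR at step 0
    min_lr = initial_lr / final_div_factor  # LR at the final step
    warmup_end = pct_start * total_steps - 1
    if warmup_end <= 0:
        raise ValueError("total_steps is too small for this pct_start")
    if step <= warmup_end:
        # Phase 1: anneal upward from initial_lr to max_lr
        return _cos_anneal(initial_lr, max_lr, step / warmup_end)
    # Phase 2: anneal downward from max_lr to min_lr
    return _cos_anneal(max_lr, min_lr, (step - warmup_end) / (total_steps - 1 - warmup_end))


for step in (0, 29, 99):
    print(step, one_cycle_lr(step, total_steps=100))
```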

### ShaSpec Configuration

```yaml
# Architecture
model_name: shaspec
dim: 256                           # Shared feature dimension
alpha: 0.1                         # Shared loss weight
beta: 0.02                         # Domain classification loss weight
input_dim: 498                     # EHR feature dimension
dropout: 0.2                       # Dropout rate
layers: 2                          # LSTM layers
num_classes: 6                     # Output classes

# Vision
vision_backbone: resnet50          # Vision backbone
pretrained: true                   # Use pretrained weights
vision_num_classes: 1              # Vision output classes

# Transformer
nhead: 8                           # Attention heads
num_layers: 3                      # Transformer layers
max_seq_len: 500                   # Maximum sequence length

# Training
lr: 0.0001                         # Learning rate
batch_size: 16                     # Batch size
epochs: 50                         # Number of epochs
patience: 10                       # Early stopping patience
beta_1: 0.9                        # Adam beta1 parameter
weight_std: true                   # Weight standardization
```

## 📊 Dataset Information

### MIMIC-IV Dataset
- **EHR Data**: Electronic health records with 498 features
- **Tasks**: Mortality prediction, phenotype classification
- **Splits**: Fixed train/validation/test splits, reused across 5 repeated experiments

### MIMIC-CXR Dataset
- **Images**: Chest X-ray images (resized to 224x224)
- **Preprocessing**: ResNet/Swin Transformer feature extraction
- **Pairing**: Matched with EHR data by patient ID and time
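Pairing relies on the patient ID and the acquisition time. One simple rule, sketched below, attaches the last chest X-ray of the same patient taken during the ICU stay and treats stays without one as missing the CXR modality. Whether `datasets/dataset.py` uses exactly this window is an assumption, and `Stay`, `CxrStudy`, and `match_cxr` are hypothetical names:

```python
from dataclasses import dataclass
from datetime import datetime
from typing import List, Optional


@dataclass
class Stay:  # hypothetical ICU-stay record
    subject_id: int
    intime: datetime
    outtime: datetime


@dataclass
class CxrStudy:  # hypothetical CXR metadata record
    subject_id: int
    study_time: datetime
    dicom_id: str


def match_cxr(stay: Stay, studies: List[CxrStudy]) -> Optional[CxrStudy]:
    """Return the last CXR of the same patient taken during the stay, or None."""
    candidates = [
        s for s in studies
        if s.subject_id == stay.subject_id and stay.intime <= s.study_time <= stay.outtime
    ]
    # None marks a stay with a missing CXR modality
    return max(candidates, key=lambda s: s.study_time, default=None)
```

Under `--matched`, stays for which such a function returns `None` would be dropped; without it, they would be kept as EHR-only samples.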

### Experimental Setup
- **Repeated Experiments**: Each model is trained 5 times (fold 1-5) to ensure statistical reliability
- **Data Splits**: Train/validation/test splits are consistent across all experiments, with only the random initialization varying
- **Evaluation**: Results are averaged across all 5 repetitions to provide robust performance metrics

### Data Paths Configuration

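```bash
# Default data paths are defined in arguments.py. Edit the defaults there,
# or override them on the command line:
python ../main.py \
    --model drfuse \
    --mode train \
    --task phenotype \
    --ehr_root /path/to/ehr/data \
    --resized_cxr_root /path/to/cxr/data \
    --pkl_dir /path/to/pkl/data \
    --image_meta_path /path/to/metadata.csv
```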

## 🔧 Adding New Models

### Step 1: Create Model Directory

```bash
mkdir models/your_model_name
```

### Step 2: Implement Model Class

```python
# models/your_model_name/your_model.py
from ..base import BaseFuseTrainer
from ..registry import ModelRegistry


@ModelRegistry.register('your_model_name')  # name passed via --model
class YourModel(BaseFuseTrainer):
    def __init__(self, hparams):
        super().__init__(hparams)
        # Initialize your model components (encoders, fusion module, classifier)

    def forward(self, data_dict):
        # Implement the forward pass and return the model outputs
        raise NotImplementedError

    def training_step(self, batch, batch_idx):
        # Compute and return the training loss; Lightning backpropagates the returned value
        raise NotImplementedError
```

### Step 3: Register Model

```python
# models/your_model_name/__init__.py
from .your_model import YourModel

# models/__init__.py
from .your_model_name import YourModel
```
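The `@ModelRegistry.register` decorator only runs when its module is imported, which is why the class is imported in both `__init__.py` files. `models/registry.py` is not reproduced here; a minimal decorator-based registry of this kind, with a lookup by the `--model` name, could look like the following (the `get` method and error messages are assumptions):

```python
from typing import Callable, Dict, Type


class ModelRegistry:
    """Minimal name -> class registry (hypothetical sketch of models/registry.py)."""

    _models: Dict[str, Type] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[Type], Type]:
        def decorator(model_cls: Type) -> Type:
            if name in cls._models:
                raise ValueError(f"Model '{name}' is already registered")
            cls._models[name] = model_cls
            return model_cls  # return the class unchanged so it stays importable
        return decorator

    @classmethod
    def get(cls, name: str) -> Type:
        try:
            return cls._models[name]
        except KeyError:
            raise KeyError(f"Unknown model '{name}'. Registered: {sorted(cls._models)}") from None


@ModelRegistry.register("toy")
class ToyModel:
    def __init__(self, hparams):
        self.hparams = hparams


model = ModelRegistry.get("toy")({"lr": 1e-4})  # e.g. resolved from --model toy
```

Rejecting duplicate names catches copy-pasted decorators early instead of silently replacing an existing model.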

### Step 4: Add Configuration

```yaml
# configs/your_model_name.yaml
model_name: your_model_name
# Add your model-specific parameters
hidden_size: 256
lr: 0.0001
epochs: 50
patience: 10
# ... other parameters
```

### Step 5: Create Training Script

```bash
#!/bin/bash
# scripts/train_your_model_name.sh

python ../main.py \
    --model your_model_name \
    --mode train \
    --task phenotype \
    --matched \
    --fold 1 \
    --patience 10 \
    --epochs 50 \
    --batch_size 16 \
    --gpu 0
```

## 📈 Experiment Results

Results are automatically saved in the `experiments/` directory:

```
experiments/
├── for_{model_name}/
│   ├── {task}/
│   │   ├── lightning_logs/
│   │   │   ├── version_X/
│   │   │   │   ├── checkpoints/
│   │   │   │   ├── events.out.tfevents.*
│   │   │   │   ├── hparams.yaml
│   │   │   │   ├── metrics.csv
│   │   │   │   └── test_set_results.yaml
```

### Key Metrics
- **AUROC** (Area Under ROC Curve): Measures the model's ability to distinguish between classes across different thresholds
- **AUPRC** (Area Under Precision-Recall Curve): Summarizes the precision-recall trade-off across thresholds; more informative than AUROC when positive cases are rare
- **Accuracy (ACC)**: Overall proportion of correct predictions
- **Precision**: Proportion of true positives among all positive predictions
- **Recall**: Proportion of actual positives correctly identified
- **Specificity**: Proportion of actual negatives correctly identified
- **F1 Score**: Harmonic mean of precision and recall, balancing both metrics
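For a binary task such as mortality, the threshold-based metrics follow directly from confusion-matrix counts, and AUROC equals the probability that a randomly chosen positive is scored above a randomly chosen negative. The dependency-free sketch below assumes a 0.5 decision threshold (the threshold the benchmark uses is not stated here) and omits AUPRC; in practice scikit-learn's `roc_auc_score` and `average_precision_score` are the usual choices. For the multi-label phenotype task, such metrics would typically be computed per label and then averaged:

```python
from typing import Dict, Sequence


def binary_metrics(y_true: Sequence[int], y_score: Sequence[float],
                   threshold: float = 0.5) -> Dict[str, float]:
    """Threshold-based metrics plus AUROC for one binary label."""
    y_pred = [int(s >= threshold) for s in y_score]
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)

    def safe_div(a: float, b: float) -> float:
        return a / b if b else 0.0

    precision = safe_div(tp, tp + fp)
    recall = safe_div(tp, tp + fn)  # also called sensitivity
    # AUROC: fraction of (positive, negative) pairs ranked correctly; ties count 1/2
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    pairs = [(p > n) + 0.5 * (p == n) for p in pos for n in neg]
    return {
        "acc": safe_div(tp + tn, len(y_true)),
        "precision": precision,
        "recall": recall,
        "specificity": safe_div(tn, tn + fp),
        "f1": safe_div(2 * precision * recall, precision + recall),
        "auroc": safe_div(sum(pairs), len(pairs)),
    }
```

The pairwise AUROC is quadratic in the number of samples, which is fine for illustration but slow on a full test set; the scikit-learn functions use a sort-based computation instead.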

These metrics are automatically calculated and logged during training and testing, with results saved in the experiment logs directory.

## 📦 Installation

### Requirements

```bash
pip install -r requirements.txt
```

### Key Dependencies

```
torch>=1.12.0
lightning>=2.0.0
torchvision>=0.13.0
numpy>=1.21.0
pandas>=1.3.0
scikit-learn>=1.0.0
PyYAML>=6.0
tensorboard>=2.8.0
```

### GPU Requirements
- NVIDIA GPU with CUDA support
- At least 16 GB of GPU memory recommended


## 📝 License

This source code is released under the [MIT License](LICENSE). We do not own any of the datasets used in or included with this repository.

## 📚 Citation

If you use this benchmark in your research, please cite:

```bibtex
@article{benchmark-multimodal-clinical-learners,
  title={},
  author={},
  journal={},
  year={2025}
}
```

**Note**: This benchmark is designed for research purposes. Ensure compliance with data usage agreements and ethical guidelines when using clinical datasets.