# Low-Rank Bayesian Neural Networks on MIMIC-III

This repository contains a comprehensive implementation and evaluation of **Low-Rank Bayesian Neural Networks** for ICU mortality prediction using the MIMIC-III clinical database.

## Overview

This project compares different Bayesian deep learning architectures for **performance**, **uncertainty quantification** and **out-of-distribution (OOD) detection** in clinical prediction tasks:

- **Full-Rank Bayes by Backprop (BBB)**: Standard variational inference with diagonal Gaussian posteriors
- **Low-Rank Gaussian Factorization**: Memory-efficient approximation using W ≈ AB^T decomposition
- **Rank-1 Multiplicative**: Rank-1 multiplicative weight perturbations on deterministic weights (Dusenberry et al., 2020)
- **Deep Ensemble**: Baseline using multiple deterministic networks
- **Deterministic Dense**: Single deterministic network for comparison

## Key Features

- ✓ **Parameter Efficiency**: Low-rank models reduce parameters from O(d_in × d_out) to O((d_in + d_out) × rank)
- ✓ **Uncertainty Quantification**: Mutual Information (MI) based epistemic uncertainty estimation
- ✓ **OOD Detection**: Newborn patients used as a natural OOD cohort
- ✓ **Calibration Analysis**: Expected Calibration Error (ECE) with multiple binning strategies
- ✓ **Seed Robustness**: Multi-seed experiments to verify reproducibility
- ✓ **Comprehensive Metrics**: AUROC, AUPR, NLL, ECE, OOD detection performance
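
To make the parameter-efficiency claim concrete, here is a minimal sketch that counts variational parameters for one dense layer. It assumes each Gaussian weight stores two values (a mean and a scale) and ignores biases. The layer sizes and the function names are hypothetical and not taken from `modules/model_builders.py`.

```python
def full_rank_params(d_in, d_out, per_weight=2):
    """Diagonal-Gaussian posterior over a full d_in x d_out weight matrix."""
    return per_weight * d_in * d_out


def low_rank_params(d_in, d_out, rank, per_weight=2):
    """Gaussian factors A (d_in x rank) and B (d_out x rank), with W ~ A B^T."""
    return per_weight * (d_in + d_out) * rank


# Hypothetical layer sizes, for illustration only.
d_in, d_out, rank = 512, 256, 8
full = full_rank_params(d_in, d_out)
low = low_rank_params(d_in, d_out, rank)
print(f"full-rank: {full}, low-rank: {low}, reduction: {full / low:.1f}x")
```

The saving grows with layer width, because the full-rank count scales with the product of the two dimensions while the low-rank count scales with their sum.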

---

## Project Structure

```
MIMIC-3 experiment/
│
├── modules/                          # Core implementation modules
│   ├── config.py                    # Configuration and hyperparameters
│   ├── bayesian_layers.py           # Bayesian layer implementations
│   ├── model_builders.py            # Model architecture builders
│   ├── inference.py                 # MC sampling and uncertainty quantification
│   ├── metrics.py                   # Evaluation metrics (AUROC, AUPR, ECE, MI)
│   ├── evaluation.py                # High-level evaluation orchestration
│   ├── training.py                  # Training routines
│   ├── data_utils.py                # Data loading and preprocessing
│   ├── visualization.py             # Plotting and visualization functions
│   └── MIMIC_3_data_preprocessing.py # MIMIC-III data preprocessing pipeline
│
├── MIMIC-3 final.ipynb              # Main experiment notebook
├── seed_robustness_experiment.ipynb # Multi-seed robustness analysis
│
├── data2/                           # Data directory (not included)
│   ├── processed/                   # Processed CSV files
│   └── interim/                     # Intermediate files
│
├── checkpoints/                     # Saved model weights
├── figures/                         # Generated plots and visualizations
├── results_csv/                     # Evaluation results in CSV format
│   ├── results_mi_based.csv        # Main MI-based evaluation results
│   ├── rank_sweep_results.csv      # Rank selection experiment results
│   └── model_statistics.csv        # Parameter counts and memory usage
│
└── seed_robustness_results/         # Multi-seed analysis
    ├── seed_robustness_data.json    # Results across 5 seeds
    ├── analysis.json                # Statistical analysis
    └── seed_robustness_analysis.py  # Analysis script
```

---

## Installation

### Prerequisites

- Python 3.8+
- TensorFlow 2.15.0
- CUDA-compatible GPU 
- MIMIC-III v1.4 database access

### Dependencies

```bash
pip install tensorflow==2.15.0
pip install tensorflow-probability
pip install pandas numpy scikit-learn
pip install matplotlib seaborn
pip install duckdb  # For efficient data preprocessing
```

### Hardware Setup

Experiments were run on the following hardware:
- **GPU**: NVIDIA H100 NVL (94 GB VRAM)
- **Driver Version**: 580.95.05
- **CUDA Version**: 13.0
- **Max Power**: 400 W
- **GPU Memory**: 95,830 MiB total

---

## Data Access

**Important**: The MIMIC-III dataset is **not included** in this repository due to the PhysioNet Data Use Agreement. To reproduce the experiments, you must obtain access to the data yourself.

### Steps to Access MIMIC-III:

1. **Create a PhysioNet account**: Go to [https://physionet.org/](https://physionet.org/) and register
2. **Complete CITI training**: Required for credentialed access to restricted datasets
3. **Request access to MIMIC-III**: Visit [MIMIC-III Clinical Database v1.4](https://physionet.org/content/mimiciii/1.4/) and submit your access request
4. **Download the data**: Once approved, download the database files

### Preprocessing

After obtaining the raw MIMIC-III data, run the preprocessing script included in `modules/`:

```bash
python modules/MIMIC_3_data_preprocessing.py \
    --data_dir /path/to/mimic-iii-clinical-database-1.4 \
    --interim_dir data2/interim \
    --processed_dir data2/processed
```

This will generate all necessary files in `data2/` including training/test splits and the OOD (newborn) cohort.

---

**Key outputs of the main experiment notebook (`MIMIC-3 final.ipynb`):**
- Trained models saved to `checkpoints/`
- Evaluation metrics in `results_csv/results_mi_based.csv`
- Visualizations in `figures/`

### Seed Robustness Analysis

Run [seed_robustness_experiment.ipynb](seed_robustness_experiment.ipynb) to verify reproducibility across 5 random seeds (42, 123, 256, 789, 2024).

---

### Preprocessing Pipeline

The preprocessing pipeline ([MIMIC_3_data_preprocessing.py](modules/MIMIC_3_data_preprocessing.py)) reproduces the methodology from [Pacmed/aisg_2019](https://github.com/Pacmed/aisg_2019):

1. Patient demographics from `PATIENTS`, `ADMISSIONS`, `ICUSTAYS` tables
2. Vital signs aggregation (mean/std per ICU stay)
3. Lab results aggregation (mean/std per ICU stay)
4. Outlier removal (8×IQR rule)
5. Median imputation and min-max scaling
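
The last two steps can be sketched for a single feature column as follows. This is an illustration under assumptions, not the code in `MIMIC_3_data_preprocessing.py`: it reads the 8×IQR rule as "values outside [Q1 - 8·IQR, Q3 + 8·IQR] are outliers", treats outliers as missing before imputation, and computes all statistics on the column it is given. The function name `preprocess_column` is hypothetical.

```python
import statistics


def preprocess_column(values, k=8.0):
    """Outlier masking (k x IQR), median imputation, then min-max scaling.

    `values` is a list of floats where None marks a missing measurement.
    """
    observed = [v for v in values if v is not None]
    q1, _, q3 = statistics.quantiles(observed, n=4)
    iqr = q3 - q1
    lower, upper = q1 - k * iqr, q3 + k * iqr

    # Assumption: outliers are treated as missing rather than dropping the row.
    kept = [v if v is not None and lower <= v <= upper else None for v in values]

    median = statistics.median([v for v in kept if v is not None])
    filled = [median if v is None else v for v in kept]

    vmin, vmax = min(filled), max(filled)
    span = (vmax - vmin) or 1.0  # avoid division by zero for constant columns
    return [(v - vmin) / span for v in filled]


# Hypothetical heart-rate-like values with one gap and one implausible reading.
column = [60, 65, 70, 75, None, 80, 85, 90, 95, 10000]
print(preprocess_column(column))
```

In a train/test setting the quartiles, median, and min/max would normally be estimated on the training split and reused for the test and OOD cohorts.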

---

### Hyperparameters

| Parameter | Value | Notes |
|-----------|-------|-------|
| Batch size | 64 | For Bayesian models |
| Learning rate | 1e-3 | Adam optimizer |
| Epochs (Bayesian) | 256 | Full training |
| Epochs (Deterministic) | 32 | Prevents overfitting |
| KL scaling | 0.5 / N_batches | Minibatch VI correction |
| Class weights | {0: 1.0, 1: 11.88} | Handles 8.4% positive class |
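
As a rough illustration of the KL scaling row, the sketch below reads `0.5 / N_batches` as a per-minibatch weight on the KL term, with `N_batches` taken to be the number of minibatches per epoch. The helper names and the example numbers are hypothetical; the actual loss lives in `modules/training.py` and may differ.

```python
import math


def kl_weight(n_train, batch_size=64, kl_scale=0.5):
    """Weight applied to the KL term on each minibatch."""
    n_batches = math.ceil(n_train / batch_size)
    return kl_scale / n_batches


def minibatch_loss(nll, kl, n_train, batch_size=64):
    """Negative ELBO estimate for one minibatch: data term plus scaled KL."""
    return nll + kl_weight(n_train, batch_size) * kl


# Hypothetical values: 6,400 training stays, batch NLL 0.3, total KL 200.
print(kl_weight(6400))
print(minibatch_loss(nll=0.3, kl=200.0, n_train=6400))
```

Dividing by the number of batches spreads the KL penalty over one epoch, so that summing the minibatch losses counts the (scaled) KL term once rather than once per batch.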

### Reproducibility

All models trained with:
- Fixed random seeds 
- Deterministic TensorFlow operations (`TF_DETERMINISTIC_OPS=1`)
- GPU memory growth enabled
- Float32 precision (no mixed precision)
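
A minimal setup sketch for these settings is shown below. It is an assumption about how such a setup is commonly written, not a copy of `modules/config.py`; the function name `set_reproducible` is hypothetical. TensorFlow and NumPy are imported lazily so the snippet also runs where they are not installed.

```python
import os
import random


def set_reproducible(seed=42):
    """Seed the available RNGs and request deterministic TensorFlow ops.

    Returns True if TensorFlow was found and configured, False otherwise.
    """
    # Safest to set this before TensorFlow is imported anywhere in the process.
    os.environ["TF_DETERMINISTIC_OPS"] = "1"
    random.seed(seed)

    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass

    try:
        import tensorflow as tf
    except ImportError:
        return False

    tf.random.set_seed(seed)
    # Memory growth must be configured before the GPUs are first used.
    for gpu in tf.config.list_physical_devices("GPU"):
        tf.config.experimental.set_memory_growth(gpu, True)
    return True


set_reproducible(42)
```

Calling the function once at the top of a notebook, before any model is built, keeps the seed and GPU configuration in one place.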

---

## Evaluation Metrics

### Classification Performance

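- **AUROC**: Area under the ROC curve, i.e. the probability that a randomly chosen positive case receives a higher predicted score than a randomly chosen negative case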
- **NLL**: Negative log-likelihood (binary cross-entropy)
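
Both definitions can be written out directly. The sketch below is dependency-free and meant for illustration; `modules/metrics.py` may well use scikit-learn instead. The pairwise AUROC is quadratic in the number of samples, so it is only suitable for small examples.

```python
import math


def auroc(y_true, y_score):
    """Fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def nll(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy, with probabilities clipped away from 0 and 1."""
    total = 0.0
    for y, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1.0 - eps)
        total -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return total / len(y_true)


# Toy labels and predicted mortality probabilities, for illustration only.
y_true = [0, 0, 1, 1]
y_prob = [0.1, 0.4, 0.35, 0.8]
print(auroc(y_true, y_prob), nll(y_true, y_prob))
```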

### Uncertainty Quality

**MI-based (Primary)**:
- **AUPR-Success (MI)**: How well low MI identifies correct predictions
- **AUPR-Error (MI)**: How well high MI identifies incorrect predictions
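- **Mutual Information**: I[y; ω | x, D] = H[y|x,D] - E_{ω|D}[H[y|x,ω]], i.e. the entropy of the averaged prediction minus the average entropy of the individual sampled predictions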
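
For a binary outcome, the MI formula above can be estimated from Monte Carlo samples of the predicted probability. The sketch below assumes one input `x` and a list of probabilities P(y=1 | x, ω_s) from S sampled weight sets; the function names are hypothetical and not those of `modules/inference.py`.

```python
import math


def binary_entropy(p, eps=1e-12):
    """Entropy (in nats) of a Bernoulli(p) distribution."""
    p = min(max(p, eps), 1.0 - eps)
    return -(p * math.log(p) + (1.0 - p) * math.log(1.0 - p))


def mutual_information(mc_probs):
    """MI estimate: entropy of the mean prediction minus mean per-sample entropy."""
    mean_p = sum(mc_probs) / len(mc_probs)
    total_entropy = binary_entropy(mean_p)          # H[y | x, D]
    expected_entropy = sum(binary_entropy(p) for p in mc_probs) / len(mc_probs)
    return total_entropy - expected_entropy         # epistemic part


# Samples that agree give MI near zero; samples that disagree give high MI.
print(mutual_information([0.5, 0.5, 0.5]))
print(mutual_information([0.05, 0.95, 0.10, 0.90]))
```

Note that the first example has maximal predictive entropy but zero MI: the model is uncertain about the outcome, yet all weight samples agree, so the uncertainty is not epistemic.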

**STD-based (Comparison)**:
- The same AUPR-Success and AUPR-Error metrics, computed with the predictive standard deviation in place of MI

### OOD Detection

- **AUROC-OOD (MI)**: Discriminating in-domain vs OOD using MI
- **AUPR-In-Domain (MI)**: Low MI identifies in-domain samples
- **AUPR-OOD (MI)**: High MI identifies OOD samples
- **MI Ratio**: Mean MI (OOD) / Mean MI (In-domain)
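
The AUPR and ratio metrics above can be sketched as follows, given per-sample MI scores for the in-domain test set and the OOD cohort. AUPR is approximated here by step-wise average precision without special tie handling, which is an assumption about the estimator; the function names are hypothetical.

```python
def average_precision(labels, scores):
    """Step-wise area under the precision-recall curve (higher score = positive)."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    n_pos = sum(labels)
    hits, ap = 0, 0.0
    for rank, i in enumerate(order, start=1):
        if labels[i] == 1:
            hits += 1
            ap += hits / rank  # precision at each recalled positive
    return ap / n_pos


def ood_summary(mi_in, mi_ood):
    """OOD metrics from MI scores of in-domain and OOD samples."""
    scores = list(mi_in) + list(mi_ood)
    is_ood = [0] * len(mi_in) + [1] * len(mi_ood)
    is_in = [1 - y for y in is_ood]
    return {
        # High MI should flag OOD samples.
        "aupr_ood": average_precision(is_ood, scores),
        # Low MI should flag in-domain samples, so the score is negated.
        "aupr_in": average_precision(is_in, [-s for s in scores]),
        "mi_ratio": (sum(mi_ood) / len(mi_ood)) / (sum(mi_in) / len(mi_in)),
    }


# Toy MI values, for illustration only.
print(ood_summary(mi_in=[0.01, 0.02, 0.03], mi_ood=[0.04, 0.08]))
```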

### Calibration

- **ECE**: Expected Calibration Error with multiple binning strategies:
  - Equal-width (10, 15, 20 bins)
  - Equal-mass (10, 15, 20 bins)
- **Best ECE**: Minimum ECE across all configurations
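
A compact sketch of the two binning strategies and the "best ECE" selection follows. It calibrates the predicted positive-class probability against the observed positive rate in each bin, which is one common convention for binary tasks and an assumption here; `modules/metrics.py` may bin differently.

```python
def ece(y_true, y_prob, n_bins=10, strategy="width"):
    """Expected Calibration Error with equal-width or equal-mass bins."""
    n = len(y_true)
    pairs = sorted(zip(y_prob, y_true))
    if strategy == "width":
        bins = [[] for _ in range(n_bins)]
        for p, y in pairs:
            idx = min(int(p * n_bins), n_bins - 1)  # p == 1.0 goes in the last bin
            bins[idx].append((p, y))
    else:  # "mass": each bin holds roughly n / n_bins sorted samples
        bins = [pairs[i * n // n_bins:(i + 1) * n // n_bins] for i in range(n_bins)]

    total = 0.0
    for b in bins:
        if b:
            confidence = sum(p for p, _ in b) / len(b)
            frequency = sum(y for _, y in b) / len(b)
            total += len(b) / n * abs(frequency - confidence)
    return total


def best_ece(y_true, y_prob, bin_counts=(10, 15, 20)):
    """Minimum ECE over all binning configurations."""
    return min(ece(y_true, y_prob, b, s)
               for b in bin_counts for s in ("width", "mass"))


# Toy example, for illustration only.
print(best_ece([0, 0, 1, 1], [0.2, 0.3, 0.7, 0.9]))
```

Because the minimum over configurations is optimistic by construction, it is usually worth reporting the individual ECE values alongside it.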

---

## Acknowledgments

- **MIMIC-III**: Johnson et al. (2016) - Clinical database
- **Preprocessing Pipeline**: Based on [Pacmed/aisg_2019](https://github.com/Pacmed/aisg_2019)
- **Bayes by Backprop**: Blundell et al. (2015)
- **Rank-1 Bayesian Neural Networks**: Dusenberry et al. (2020)

---

