medalign — Lightweight implementation scaffold for ILCR MedAlign

A modular, research-oriented codebase scaffold implementing the components described in the ILCR MedAlign paper (CAFW, HSF, CDE, PCRL, RL-Gating, DP aggregation, compression, etc.).
This repository provides minimal, well-documented skeletons for each major component so you can quickly fill in algorithmic details, run experiments, and iterate on reproducible evaluations.

Status: scaffold / skeleton code — core interfaces and small working examples included. Replace TODO markers with paper-specific equations and hyperparameters.

Repository layout:
medalign/
├─ configs.py                # default hyperparameters and paths
├─ utils.py                  # logging, checkpointing, metrics (Accuracy/MCC/SNR), helpers
├─ data.py                   # PyTorch Dataset, augmentations (temporal warp, noise, channel mask)
├─ models.py                 # encoder, projection head, classifier, policy skeleton
├─ cafw.py                   # Context-Aware Feature Weighting (ontology embedding + scorer)
├─ hsf.py                    # Heterogeneous Stream Fusion (cross-attention fusion)
├─ cde.py                    # Clinical Dependency Encoder (dense GAT layers)
├─ pcrl.py                   # Prototype-Consistent Representation Learning (prototypes + NT-Xent)
├─ rl_gating.py              # RL-Gating environment and small REINFORCE agent
├─ privacy.py                # DP helpers: clipping, sigma calc, secure_aggregate placeholder
├─ compression.py            # sparsify, quantize, residual compensation
├─ client.py                 # client pipeline that composes modules and returns simulated updates
├─ server.py                 # server pipeline for client selection and DP aggregation
├─ train.py                  # simple experiment driver (federated simulation + export)
├─ eval.py                   # evaluation helpers (accuracy, mcc, snr, psnr, ssim placeholder)
├─ deploy_edge.py            # edge export/inference helpers (Raspberry Pi simulation)
├─ scripts/
│   ├─ run_fed_experiment.py
│   ├─ run_client_sim.py
│   └─ analyze_ablation.py
└─ README.md


Design & modules:

configs.py: Contains default hyperparameters and a get_config helper. Override defaults with a JSON config file or CLI arguments.

utils.py: Common utilities for seeding, logging, checkpointing, and classic metrics (accuracy, Matthews correlation coefficient, SNR in dB).

data.py: Minimal PyTorch Dataset for multivariate time-series with augmentation utilities:

temporal warp (±15% by default)

additive Gaussian noise to match target SNR (20 dB default)

channel masking

models.py: Lightweight TimeSeriesEncoder (1D conv stack), ProjectionHead, ClassifierHead, and PolicyNetwork for RL gating.

cafw.py: Ontology embedding + scoring network to compute per-feature weights (CAFW).

hsf.py: Cross-attention fusion of heterogeneous modality embeddings.

cde.py: Dense GAT-like layers for clinical dependency encoding. Includes adjacency builder utilities.

pcrl.py: Prototype store, NT-Xent contrastive loss, and simple compression helpers for prototype deltas.

rl_gating.py: Environment (state builder) and a simple REINFORCE agent. Replace with actor-critic / PPO if needed.

privacy.py: Computes Gaussian noise std for DP, provides vector clipping and a simulated secure_aggregate that adds noise.

compression.py: Two-mode compression: top-k sparsification, uniform quantization, and residual compensation for error feedback.

client.py / server.py: High-level composition for federated simulation. They use the modules above and demonstrate the data and update flow.

train.py: High-level driver for running federated experiments and saving metrics.

eval.py: Evaluation helpers (accuracy, MCC, SNR, PSNR/SSIM placeholders).

deploy_edge.py: Build a stripped model to deploy to constrained devices; includes simple inference wrapper.


Extending the scaffold:

Suggested next steps to implement the full paper pipeline:

Implement the exact CAFW scoring and ontology processing from the paper (replace simple scorer with domain-specific function).

Replace the GAT dense implementation with sparse messages or use PyG for large graphs.

Implement PCRL prototype alignment rules and server-client prototype synchronization.

Replace REINFORCE with an actor-critic agent (or PPO) for RL-Gating sample efficiency.

Add a real DP accountant (RDP-based) and secure aggregation (MPC or aggregator with encryption) for production privacy.

Add unit tests for each module and CI integration.



Testing & debugging tips:

Use medalign.data.synthetic_dataset for deterministic tests of data pipeline and model forward passes.

To debug numerical issues, set torch.set_printoptions(profile="full") and run on CPU with small batch sizes.

Use the scripts/run_client_sim.py to validate the client pipeline in isolation before full federated runs.

Log intermediate tensor shapes in models.py / client.py with logger.debug(...) to trace shape mismatches.