# UdonCare

This folder contains the PyTorch implementation of **UdonCare**, our hierarchy-guided, mutual-learning model for robust EHR prediction. It defines the backbone encoder, feature encoders, domain discovery modules, prediction heads, and the top-level model wrapper.

## Files

- **`backbone.py`**  
  Transformer building blocks (attention, feed-forward, and residual connections) and `TransformerLayer`, which serves as the feature extractor over sequences of patient visits.

- **`encoders.py`**  
  Per-key encoders (`CodeKeyEncoder`) for **Diseases**, **Procedures**, and **Drugs**. Each uses the backbone Transformer to produce patient-level embeddings.

- **`heads.py`**  
  A simple MLP classification head (`LabelHead`) that is applied to patient embeddings and used in both the backbone path and the domain-invariant path.

- **`domain.py`**  
  Implements hierarchy-guided domain discovery and domain-invariant representation learning:  
  - `HierarchyPruner`, which prunes the ICD disease hierarchy into latent domains.  
  - `DomainEncoder`, which maps domain IDs to embeddings.  
  - `invariant_projection`, which applies an orthogonal projection to remove domain-specific components.

- **`udoncare.py`**  
  The main `UdonCare` class, which ties everything together. It defines the backbone path and the domain-invariant path and couples them through mutual learning.
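
As a reference point for the `backbone.py` components, the following is a minimal sketch of a residual Transformer layer over visit embeddings. It is not the repository's code: the pre-norm layout, the use of PyTorch's `nn.MultiheadAttention`, the hypothetical class name `TransformerLayerSketch`, and the visit-mask convention (`True` = real visit) are assumptions made for illustration.

```python
from typing import Optional

import torch
import torch.nn as nn


class TransformerLayerSketch(nn.Module):
    """Illustrative pre-norm Transformer layer (hypothetical; not backbone.py)."""

    def __init__(self, hidden: int, heads: int = 4, dropout: float = 0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden, heads, dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(hidden, 4 * hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(4 * hidden, hidden),
        )
        self.norm_attn = nn.LayerNorm(hidden)
        self.norm_ffn = nn.LayerNorm(hidden)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, visit_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # x: (batch, visits, hidden); visit_mask: (batch, visits), True = real visit
        pad = None if visit_mask is None else ~visit_mask  # MHA expects True = ignore
        h = self.norm_attn(x)
        attn_out, _ = self.attn(h, h, h, key_padding_mask=pad, need_weights=False)
        x = x + self.drop(attn_out)                    # residual around attention
        x = x + self.drop(self.ffn(self.norm_ffn(x)))  # residual around feed-forward
        return x


layer = TransformerLayerSketch(hidden=32)
x = torch.randn(2, 5, 32)
visit_mask = torch.tensor([[True] * 5, [True, True, True, False, False]])
out = layer(x, visit_mask)
print(out.shape)  # torch.Size([2, 5, 32])
```

The pre-norm ordering is one common choice; a post-norm layer would apply `LayerNorm` after each residual sum instead.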
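
A per-key encoder in the spirit of `CodeKeyEncoder` can be sketched under these assumptions: each visit is a padded row of integer code IDs (0 = padding), a visit embedding is the sum of its code embeddings, and the patient embedding is the mean over non-empty visits. PyTorch's built-in `nn.TransformerEncoderLayer` stands in for the backbone `TransformerLayer`, and `CodeKeyEncoderSketch` is a hypothetical name, not the released implementation.

```python
import torch
import torch.nn as nn


class CodeKeyEncoderSketch(nn.Module):
    """Illustrative encoder for one code key, e.g. Diseases (hypothetical; not encoders.py)."""

    def __init__(self, vocab_size: int, hidden: int, heads: int = 4):
        super().__init__()
        # Code ID 0 is reserved for padding and embeds to a zero vector.
        self.embed = nn.Embedding(vocab_size, hidden, padding_idx=0)
        # Built-in layer used as a stand-in for the repository's TransformerLayer.
        self.layer = nn.TransformerEncoderLayer(
            hidden, heads, dim_feedforward=4 * hidden, batch_first=True
        )

    def forward(self, codes: torch.Tensor) -> torch.Tensor:
        # codes: (batch, visits, codes_per_visit) integer IDs, 0 = padding
        visit_emb = self.embed(codes).sum(dim=2)  # (batch, visits, hidden)
        visit_mask = (codes != 0).any(dim=2)      # True for non-empty visits
        h = self.layer(visit_emb, src_key_padding_mask=~visit_mask)
        # Mean-pool over non-empty visits to get one embedding per patient.
        m = visit_mask.unsqueeze(-1).to(h.dtype)
        return (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)


encoder = CodeKeyEncoderSketch(vocab_size=100, hidden=32)
codes = torch.tensor([
    [[5, 7, 0], [12, 0, 0], [0, 0, 0]],  # patient 1: two visits, one padded
    [[3, 4, 9], [8, 1, 0], [6, 0, 0]],   # patient 2: three visits
])
patient_emb = encoder(codes)
print(patient_emb.shape)  # torch.Size([2, 32])
```

One such encoder per key (Diseases, Procedures, Drugs) would yield three patient-level embeddings that a model can then combine.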
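
`LabelHead` is described only as a simple MLP. A hypothetical two-layer version that returns raw logits might look like this; the layer sizes, the ReLU activation, the class name `LabelHeadSketch`, and the choice of one head instance per path (rather than a shared head) are assumptions.

```python
import torch
import torch.nn as nn


class LabelHeadSketch(nn.Module):
    """Illustrative two-layer MLP head (hypothetical; not heads.py)."""

    def __init__(self, hidden: int, num_labels: int, dropout: float = 0.1):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_labels),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # Returns raw logits; the loss applies softmax or sigmoid as appropriate.
        return self.mlp(z)


# Assumption: one head instance per path.
backbone_head = LabelHeadSketch(hidden=32, num_labels=10)
invariant_head = LabelHeadSketch(hidden=32, num_labels=10)
z = torch.randn(4, 32)         # a batch of patient embeddings
print(backbone_head(z).shape)  # torch.Size([4, 10])
```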
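
The `domain.py` components can be illustrated end to end on a toy hierarchy. Every name and rule here is hypothetical: `prune_hierarchy` splits a node only when each child covers at least `min_support` patients and otherwise keeps the node as one latent domain; a plain `nn.Embedding` stands in for `DomainEncoder`; and `project_out_domain` reads "orthogonal projection" as removing the component of each patient embedding along its domain embedding. The paper's actual pruning criterion and projection may differ, and the per-code patient counts are made up.

```python
from collections import defaultdict

import torch
import torch.nn as nn


def prune_hierarchy(parent, leaf_counts, min_support, root="ROOT"):
    """Hypothetical top-down pruning. A node is split into its children only if
    every child covers at least `min_support` patients; otherwise the node is
    kept as one latent domain. Returns {leaf_code: domain_id}."""
    children = defaultdict(list)
    for child, par in parent.items():
        children[par].append(child)

    totals = {}

    def total(node):  # patient count aggregated over the subtree
        if node not in totals:
            kids = children.get(node, [])
            totals[node] = sum(total(k) for k in kids) if kids else leaf_counts.get(node, 0)
        return totals[node]

    def leaves(node):
        kids = children.get(node, [])
        return [leaf for k in kids for leaf in leaves(k)] if kids else [node]

    domains = []

    def descend(node):
        kids = children.get(node, [])
        if kids and all(total(k) >= min_support for k in kids):
            for k in sorted(kids):
                descend(k)
        else:
            domains.append(node)

    descend(root)
    return {leaf: i for i, d in enumerate(domains) for leaf in leaves(d)}


def project_out_domain(z, d, eps=1e-8):
    """Remove from each patient embedding z its component along its domain embedding d."""
    coef = (z * d).sum(-1, keepdim=True) / (d * d).sum(-1, keepdim=True).clamp(min=eps)
    return z - coef * d


# Toy ICD-9-style hierarchy (child -> parent) with made-up patient counts per leaf code.
parent = {
    "240-279": "ROOT", "390-459": "ROOT",
    "250": "240-279", "401": "390-459", "428": "390-459",
    "250.00": "250", "250.01": "250", "401.9": "401", "428.0": "428",
}
leaf_counts = {"250.00": 30, "250.01": 20, "401.9": 50, "428.0": 5}
code_to_domain = prune_hierarchy(parent, leaf_counts, min_support=25)
print(code_to_domain)  # {'250.00': 0, '250.01': 0, '401.9': 1, '428.0': 1}

# DomainEncoder stand-in: one learnable vector per latent domain.
num_domains = len(set(code_to_domain.values()))
domain_encoder = nn.Embedding(num_embeddings=num_domains, embedding_dim=32)
z = torch.randn(4, 32)                   # patient embeddings
domain_ids = torch.tensor([0, 1, 1, 0])  # e.g. looked up via code_to_domain
d = domain_encoder(domain_ids)
z_inv = project_out_domain(z, d)
print((z_inv * d).sum(-1))  # entries are ~0 up to float32 rounding
```

In this toy run, `250.01` falls below `min_support`, so category `250` is kept whole, and `428` is too rare, so the whole `390-459` chapter becomes one domain.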

## Notes

- Inputs are restricted to **Diseases**, **Procedures**, and **Drugs** codes from the MIMIC datasets.  
- Domain discovery is based on the **disease hierarchy** (ICD-9).  
- Final predictions come from the **domain-invariant path**.
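
For the mutual learning in `udoncare.py`, one common formulation is deep mutual learning: each path is trained with its own supervised loss plus a KL term that pulls it toward the other path's detached prediction. The sketch below assumes a single-label, multi-class task and a hypothetical weighting factor `alpha`; UdonCare's actual objective may differ. Because final predictions come from the domain-invariant path, the example reads them from `logits_inv`.

```python
import torch
import torch.nn.functional as F


def mutual_learning_loss(logits_bb, logits_inv, labels, alpha=0.5):
    """Hypothetical objective: cross-entropy on both paths plus a KL term that
    pulls each path toward the other path's detached prediction."""
    ce = F.cross_entropy(logits_bb, labels) + F.cross_entropy(logits_inv, labels)
    log_p_bb = F.log_softmax(logits_bb, dim=-1)
    log_p_inv = F.log_softmax(logits_inv, dim=-1)
    # KL(p_inv || p_bb): gradient flows only into the backbone path.
    kl_bb = F.kl_div(log_p_bb, log_p_inv.detach(), reduction="batchmean", log_target=True)
    # KL(p_bb || p_inv): gradient flows only into the invariant path.
    kl_inv = F.kl_div(log_p_inv, log_p_bb.detach(), reduction="batchmean", log_target=True)
    return ce + alpha * (kl_bb + kl_inv)


torch.manual_seed(0)
logits_bb = torch.randn(8, 3, requires_grad=True)   # backbone-path logits
logits_inv = torch.randn(8, 3, requires_grad=True)  # invariant-path logits
labels = torch.randint(0, 3, (8,))
loss = mutual_learning_loss(logits_bb, logits_inv, labels)
loss.backward()
preds = logits_inv.argmax(dim=-1)  # final predictions read from the invariant path
```

Detaching the target in each KL term keeps every path as a "teacher" for the other without letting the two terms push both distributions toward each other at once. For multi-label tasks, the softmax and cross-entropy would be replaced with per-label sigmoids and binary cross-entropy.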

## Open-Source Commitment

We will **open-source the complete codebase**, including this `models/` directory, **once the paper is accepted**.
