# Temporal Treatment Outcome Modeling with Causal Diffusion Models

## Environment Setup

1.  **Create the Conda Environment**:
    ```bash
    conda env create -f environment.yml python=3.12.8
    ```
2.  **Activate the Environment**:
    ```bash
    conda activate diff
    ```
3.  **Prepare the Dataset from OAI**:
    Create an account at https://nda.nih.gov/oai and enter your username and password in `download_oai_data.py`. Download the OAI archive from https://nda.nih.gov/oai/full_downloads. To download the images, look up the ID of every image package at https://nda.nih.gov/user/dashboard/packages. Then, for each package ID, set `packageId` in `download_oai_data.py` to that ID and run:
    ```bash
    python download_oai_data.py
    ```
    
    Finally, create the full dataset CSV:
    ```bash
    python create_dataset.py
    ```

## Data Preparation

1.  **Create the Window Dataset and Data Splits**  
    Navigate to the `pairs_dataset` directory and run:
    ```bash
    python create-pairs-dataset.py
    ```
    This creates `pairs-dataset.csv` along with the splits.
    The exact splits used are also included in the zip file.

2.  **Crop Knee X-Rays**  
    Navigate to the `yolo` directory, set `csv_path` to the path of `pairs-dataset.csv`, and run:
    ```bash
    python create_yolo_dataset_parallel.py
    ```
    This creates a new folder containing the cropped knee X-rays.
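
The two preparation steps above can be sketched in plain Python. This is a hypothetical illustration, not the code in `create-pairs-dataset.py` or `create_yolo_dataset_parallel.py`. The following are all placeholders:

- the `Visit` fields;
- the `max_history` window;
- the assumption that the knee detector returns boxes in the usual normalized YOLO `(cx, cy, w, h)` format.

```python
from dataclasses import dataclass
from itertools import groupby


@dataclass
class Visit:
    # Hypothetical record; the real columns of pairs-dataset.csv may differ.
    patient_id: str
    side: int        # knee side code (encoding assumed)
    month: int       # visit time in months since baseline
    image_path: str


def knee_key(v):
    return (v.patient_id, v.side)


def make_pairs(visits, max_history=4):
    """Pair each visit with the next visit of the same knee, keeping a bounded history."""
    pairs = []
    ordered = sorted(visits, key=lambda v: (v.patient_id, v.side, v.month))
    for (pid, side), group in groupby(ordered, key=knee_key):
        seq = list(group)
        for t in range(len(seq) - 1):
            history = seq[max(0, t - max_history + 1): t + 1]
            pairs.append({
                "patient_id": pid,
                "side": side,
                "history_months": [v.month for v in history],
                "x_t": seq[t].image_path,         # current X-ray
                "x_next": seq[t + 1].image_path,  # follow-up X-ray to generate
                "delta_t": seq[t + 1].month - seq[t].month,
            })
    return pairs


def yolo_box_to_pixels(cx, cy, w, h, img_w, img_h):
    """Convert a normalized YOLO (cx, cy, w, h) box to a clamped pixel box (left, top, right, bottom)."""
    left = max(0, round((cx - w / 2) * img_w))
    top = max(0, round((cy - h / 2) * img_h))
    right = min(img_w, round((cx + w / 2) * img_w))
    bottom = min(img_h, round((cy + h / 2) * img_h))
    return left, top, right, bottom
```

The resulting pixel box has the `(left, upper, right, lower)` layout that Pillow's `Image.crop` expects.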

## Pretraining Steps

### 1. X-Ray Grade Predictors 

These models predict specific X-ray grades (e.g., Kellgren-Lawrence grade, Joint Space Narrowing).

Before training the diffusion models, pretrain the X-ray grade predictors for each feature listed in `PERCEPTUAL_FEATURES` (defined in `data/pairs_dataset/dataset.py`). For example:

```bash
python pretrain_xray_grade_model.py --feature_name KLGrade --epochs 50 --batch_size 128
python pretrain_xray_grade_model.py --feature_name JSN_Medial --epochs 50 --batch_size 128
```
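
Because one predictor is trained per feature, the runs can be scripted. The sketch below only builds and prints the commands shown above; set `dry_run=False` to actually execute them. The feature names are copied from the example commands. Check whether your copy of `pretrain_xray_grade_model.py` expects these names or the `PERCEPTUAL_FEATURES` codes.

```python
import shlex
import subprocess

# Names copied from the example commands above; adjust to what the script expects.
FEATURES = ["KLGrade", "JSN_Medial"]


def build_pretrain_commands(features, epochs=50, batch_size=128):
    """Return one pretraining command (as an argv list) per grade feature."""
    return [
        ["python", "pretrain_xray_grade_model.py",
         "--feature_name", name,
         "--epochs", str(epochs),
         "--batch_size", str(batch_size)]
        for name in features
    ]


def run_all(features, dry_run=True):
    for argv in build_pretrain_commands(features):
        print(shlex.join(argv))
        if not dry_run:
            # check=True stops at the first feature whose training fails
            subprocess.run(argv, check=True)


if __name__ == "__main__":
    run_all(FEATURES, dry_run=True)
```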
Make sure `PERCEPTUAL_FEATURES` in `data/pairs_dataset/dataset.py` lists the grade features you want to use, for example:
```python
PERCEPTUAL_FEATURES = [
    "V00XRKL",    # Example: Kellgren-Lawrence Grade
    "V00XRJSM",   # Example: Medial Joint Space Narrowing
]
```
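
Once pretrained, the predictors serve as a perceptual signal during diffusion training. The exact loss is not documented here. The sketch below assumes one plausible form:

- a cross-entropy between each frozen predictor's output on the generated follow-up image and the grade of the real follow-up image;
- averaged over features.

Plain Python lists stand in for tensors. `perceptual_grade_loss` is a hypothetical name, and the logits in the example are illustrative values only.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def perceptual_grade_loss(generated_logits, target_grades):
    """Mean cross-entropy between grade predictions on the generated image and target grades.

    generated_logits: {feature_name: [logit per grade]} from the frozen grade predictors
    target_grades:    {feature_name: integer grade of the real follow-up X-ray}
    """
    losses = []
    for feature, logits in generated_logits.items():
        grade = target_grades[feature]
        losses.append(-log_softmax(logits)[grade])
    return sum(losses) / len(losses)


# Illustrative call: KL grade has 5 classes (0-4), medial JSN has 4 (0-3).
loss = perceptual_grade_loss(
    {"V00XRKL": [0.0, 0.0, 0.0, 0.0, 0.0], "V00XRJSM": [2.0, 0.0, 0.0, 0.0]},
    {"V00XRKL": 2, "V00XRJSM": 0},
)
```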

### 2. Temporal Propensity Model (for IPW Pipeline)

This model is pretrained to estimate treatment propensities based on historical sequences and current patient/image context. It is used exclusively by the `diff_ipw.py` pipeline.

Pretrain the temporal propensity model:
```bash
python pretrain_propensity_model_temporal.py
```

## Training Diffusion Models

Both diffusion pipelines are conditioned using a `ContextEncoderRNN` that processes:
-   Visual features from the current input X-ray (`X_t`).
-   Historical sequences of covariates (`cov_seq_hist`).
-   Historical sequences of treatments (`trt_seq_hist`).
-   Time delta to the next scan (`delta_t`).
-   Side of the knee (`side`).

Both pipelines also compute a perceptual loss using the pretrained X-ray grade predictors.
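
To make the conditioning inputs concrete, here is a deliberately tiny, hypothetical stand-in for `ContextEncoderRNN`. It is an untrained Elman RNN in plain Python that summarizes `cov_seq_hist` and `trt_seq_hist`, then appends the image features, `delta_t`, and `side`. This README does not describe the real module's architecture (cell type, embeddings, scaling of `delta_t`), and it will differ.

```python
import math
import random


class TinyContextEncoder:
    """Hypothetical minimal context encoder: Elman RNN over history + static inputs."""

    def __init__(self, step_dim, hidden_dim, seed=0):
        rng = random.Random(seed)
        self.hidden_dim = hidden_dim
        # Small random weights; a real encoder would learn these.
        self.W = [[rng.uniform(-0.1, 0.1) for _ in range(step_dim)] for _ in range(hidden_dim)]
        self.U = [[rng.uniform(-0.1, 0.1) for _ in range(hidden_dim)] for _ in range(hidden_dim)]
        self.b = [0.0] * hidden_dim

    def encode_history(self, cov_seq_hist, trt_seq_hist):
        h = [0.0] * self.hidden_dim
        for cov, trt in zip(cov_seq_hist, trt_seq_hist):
            x = list(cov) + list(trt)  # one time step: covariates followed by treatments
            h = [
                math.tanh(
                    sum(w * xi for w, xi in zip(self.W[i], x))
                    + sum(u * hj for u, hj in zip(self.U[i], h))
                    + self.b[i]
                )
                for i in range(self.hidden_dim)
            ]
        return h

    def __call__(self, image_features, cov_seq_hist, trt_seq_hist, delta_t, side):
        h = self.encode_history(cov_seq_hist, trt_seq_hist)
        # Context = image features (X_t) + history summary + raw delta_t and side.
        return list(image_features) + h + [float(delta_t), float(side)]
```

For example, two visits with two covariates and one treatment indicator each give `step_dim=3`. With 8 image features and `hidden_dim=4`, the context vector has 8 + 4 + 2 = 14 entries.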

---
### 1. Adversarial Diffusion Model (`diff_adversarial.py`)

This model incorporates an adversarial loss to encourage the generator to produce images where the applied treatment cannot be easily inferred from the generated outcome and context, thus aiming for treatment-invariant representations.

**a. Train with Adversarial Loss:**
```bash
python diff_adversarial.py 
```
Evaluate with: 
```bash
python diff_adversarial.py --test_only --ckpt_path <PATH_TO_CHECKPOINT>
```
**b. Train Baseline (Feature-Informed, Non-Adversarial):**
To run this pipeline as a standard feature-informed diffusion model (without the adversarial component), set `--adversarial_weight 0`:
```bash
python diff_adversarial.py --adversarial_weight 0
```
Evaluate with: 
```bash
python diff_adversarial.py --test_only --ckpt_path <PATH_TO_CHECKPOINT>
```
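
The role of `--adversarial_weight` can be sketched as a weighted sum of losses. A treatment classifier (the adversary) tries to predict the applied treatment from the generated outcome and its context, and the generator is rewarded when the adversary fails. Setting the weight to 0 removes that term, which yields the feature-informed baseline. The function names, the `perceptual_weight` parameter, and the sign convention below are assumptions, not the code in `diff_adversarial.py`.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of an integer target under softmax(logits)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


def generator_objective(diffusion_loss, perceptual_loss, adv_logits, treatment,
                        adversarial_weight=1.0, perceptual_weight=1.0):
    """Denoising + perceptual terms, minus the adversary's treatment-prediction loss.

    Subtracting the adversary's cross-entropy lowers the generator's loss when the
    applied treatment is hard to recover from (generated outcome, context).
    """
    adv_loss = cross_entropy(adv_logits, treatment)
    return diffusion_loss + perceptual_weight * perceptual_loss - adversarial_weight * adv_loss


def adversary_objective(adv_logits, treatment):
    """The adversary is trained separately to minimize the same cross-entropy."""
    return cross_entropy(adv_logits, treatment)
```

A plain minus-cross-entropy term is unbounded. In practice it is often replaced by a gradient reversal layer or by pushing the adversary's prediction toward a uniform distribution.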
**c. Train IPW Diffusion Model:**
Once the propensity model is pretrained, run:
```bash
python diff_ipw.py
```
This script loads the pretrained propensity model and uses it to compute inverse propensity weights (IPW) that reweight the diffusion loss during training.
Evaluate with:
```bash
python diff_ipw.py --test_only --ckpt_path <PATH_TO_CHECKPOINT>
```
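
This README does not spell out how `diff_ipw.py` turns propensities into weights. One common choice, sketched here as an assumption, is a stabilized weight `w = p(a) / p(a | history)`. It is clipped to limit variance and then used in a self-normalized weighted average of per-sample losses. For sequential treatments, marginal structural models typically multiply such per-step ratios over time; this sketch uses a single step, and all names are hypothetical.

```python
def stabilized_ipw(propensities, treatments, clip=(0.05, 20.0)):
    """Stabilized inverse propensity weights for the observed treatments.

    propensities: per-sample lists of p(treatment = k | history, context)
    treatments:   observed treatment index per sample
    """
    n_classes = len(propensities[0])
    # Marginal treatment frequencies form the stabilizing numerator.
    marginal = [0.0] * n_classes
    for a in treatments:
        marginal[a] += 1.0 / len(treatments)
    lo, hi = clip
    weights = []
    for probs, a in zip(propensities, treatments):
        w = marginal[a] / max(probs[a], 1e-8)  # guard against division by zero
        weights.append(min(max(w, lo), hi))    # clip extreme weights
    return weights


def weighted_mean_loss(losses, weights):
    """Self-normalized weighted average of per-sample (e.g. diffusion) losses."""
    return sum(loss * w for loss, w in zip(losses, weights)) / sum(weights)
```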

### 2. TIDAL

Run the Python script:
```bash
python diff_adversarial_ipw2_refactored.py
```