# Hierarchical Cautious Optimization (HCO) for Semi-supervised Medical Image Segmentation

This repository contains the implementation for **Study 3** of our paper:

**"Hierarchical Cautious Optimization for Semi-supervised Segmentation of Volumetric Medical Images"**

## Study 3: CMT on the Fetal-MRI Dataset

This study evaluates whether HCO generalizes across semi-supervised learning frameworks, optimizer types, and anatomical targets by applying AdamW₍hco₎ within the collaborative mean teacher (CMT) framework for fetal lung and liver segmentation.


### Key Results

AdamW₍hco₎ demonstrated substantial improvements over standard AdamW. Percentages are relative changes; values in parentheses are AdamW→AdamW₍hco₎:

**For fetal lungs:**
- Dice increased by 24.6% (67.54→84.17)
- Jaccard increased by 39.3% (52.80→73.54)
- ASD decreased by 56.2% (8.12→3.56)
- HD95 decreased by 47.0% (25.09→13.30)

**For fetal liver:**
- Dice increased by 15.7% (70.99→82.15)
- Jaccard increased by 25.3% (56.10→70.32)
- ASD decreased by 54.9% (10.95→4.94)
- HD95 decreased by 52.2% (31.80→15.21)

All improvements were statistically significant (paired t-test, p < 1×10⁻¹⁸).
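
The percentages above are relative changes with respect to the AdamW baseline. As a quick sanity check, the short sketch below recomputes them from the reported values; the function name `relative_change` and the dictionary layout are illustrative and not part of this repository.

```python
def relative_change(baseline: float, new: float) -> float:
    """Signed relative change from baseline to new, in percent."""
    return (new - baseline) / baseline * 100.0


# (AdamW, AdamW-HCO) pairs copied from the result lists above.
results = {
    "lungs": {"Dice": (67.54, 84.17), "Jaccard": (52.80, 73.54),
              "ASD": (8.12, 3.56), "HD95": (25.09, 13.30)},
    "liver": {"Dice": (70.99, 82.15), "Jaccard": (56.10, 70.32),
              "ASD": (10.95, 4.94), "HD95": (31.80, 15.21)},
}

for organ, metrics in results.items():
    for name, (baseline, hco) in metrics.items():
        # Positive values mean an increase, negative values a decrease.
        print(f"{organ:5s} {name:7s} {relative_change(baseline, hco):+.1f}%")
```

For Dice and Jaccard a positive change is an improvement; for ASD and HD95, which are distances, a negative change is an improvement.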

## Requirements

The code was developed with Python 3.7 and PyTorch 1.11.0.
All experiments were run on a single GPU with 24 GB of memory.
Install the dependencies by creating a Conda environment from the `environment.yaml` file:
```bash
conda env create -f environment.yaml
conda activate <your_env_name>
```

## Dataset

The experiments in Study 3 use an in-house fetal body MRI dataset consisting of:
- 92 manually labeled scans
- 600 unlabeled scans
- Gestational age: 28–39 weeks
- 3T Siemens True Fast Imaging with Steady-State Precession (TRUFI) sequence acquisitions
- Voxel size: 0.78×0.78×2 mm

For the experiments, the data were partitioned as follows:
- 50 labeled cases held out for testing
- 2 labeled cases held out for validation
- 4 labeled training subsets (L₀-L₃, 5 scans each) drawn from the remaining labeled cases
- 4 unlabeled subsets (U₀-U₃, 150 scans each) covering the 600 unlabeled scans
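
The partition sizes above can be reproduced with a few lines of Python. This is only a sketch under stated assumptions: the function `make_splits`, the random shuffling, the seed, and the choice to make the labeled subsets disjoint are all hypothetical and may differ from how the actual split files were produced.

```python
import random


def make_splits(labeled_ids, unlabeled_ids, seed=0):
    """Partition case IDs into test/val sets and training subsets (illustrative)."""
    rng = random.Random(seed)  # assumption: a seeded random shuffle
    labeled = list(labeled_ids)
    unlabeled = list(unlabeled_ids)
    rng.shuffle(labeled)
    rng.shuffle(unlabeled)

    # Hold out 50 labeled cases for testing and 2 for validation.
    test, val, pool = labeled[:50], labeled[50:52], labeled[52:]

    # Four labeled subsets of 5 scans each, taken from the remaining pool
    # (assumed disjoint here).
    labeled_subsets = [pool[i * 5:(i + 1) * 5] for i in range(4)]

    # Four unlabeled subsets of 150 scans each.
    unlabeled_subsets = [unlabeled[j * 150:(j + 1) * 150] for j in range(4)]
    return test, val, labeled_subsets, unlabeled_subsets


test_ids, val_ids, L, U = make_splits(range(92), range(600))
print(len(test_ids), len(val_ids), [len(s) for s in L], [len(s) for s in U])
```

With 92 labeled scans, 40 remain after the test and validation hold-out, and the four 5-scan subsets in this sketch use 20 of them.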

## Training Protocol

Following the CMT framework, we employ a 3D V-Net backbone. Data augmentation consists of random flips and 3D crops. Each training iteration randomly samples eight patches of size 144×144×64 (4 labeled, 4 unlabeled). AdamW and its HCO variant are used with a fixed learning rate of 1×10⁻⁴. Training runs for 200 epochs.
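
To make the augmentation step concrete, the sketch below samples the parameters of one random 3D crop and random flip for a patch of size 144×144×64. It is not taken from the repository's data loaders: the function name `sample_crop_and_flips`, the 0.5 flip probability, and the example volume shape are assumptions for illustration only.

```python
import random

PATCH = (144, 144, 64)  # patch size stated in the training protocol


def sample_crop_and_flips(volume_shape, patch=PATCH, rng=random):
    """Return (start, stop) corners of a random crop and per-axis flip flags."""
    if any(v < p for v, p in zip(volume_shape, patch)):
        raise ValueError("volume is smaller than the patch; pad it first")
    # Uniformly sample the crop origin so the patch fits inside the volume.
    start = tuple(rng.randint(0, v - p) for v, p in zip(volume_shape, patch))
    stop = tuple(s + p for s, p in zip(start, patch))
    # Assumption: each axis is flipped independently with probability 0.5.
    flips = tuple(rng.random() < 0.5 for _ in patch)
    return start, stop, flips


rng = random.Random(0)
volume_shape = (256, 256, 80)  # hypothetical scan size in voxels
# One iteration draws eight patches: 4 from labeled and 4 from unlabeled scans.
batch = [sample_crop_and_flips(volume_shape, rng=rng) for _ in range(8)]
for start, stop, flips in batch[:2]:
    print(start, stop, flips)
```

The returned corners can be used to slice a volume array, and the flags indicate which axes to reverse afterwards.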

To avoid bias from a single labeled/unlabeled split, we generate 4×4=16 combinations: four labeled subsets (L₀-L₃, 5 scans each) and four unlabeled subsets (U₀-U₃, 150 scans each). CMT is trained from scratch for every pair (Lᵢ,Uⱼ). Each run took approximately 4 hours, which amounts to about 64 GPU-hours for the 16 runs of a single optimizer configuration; the total compute budget was about 256 GPU-hours.
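
The 16 (Lᵢ,Uⱼ) runs can be enumerated programmatically. The sketch below builds one command line per pair using the flags shown in the Usage section of this README. The helper `build_command` is hypothetical, and the assumption that the split files follow the pattern `train_labeled_5_split_{i}.json` and `unlabelled_split_{j}.json` for all i, j in 0..3 is extrapolated from the example paths, so adjust it to your own file layout.

```python
from itertools import product

HOURS_PER_RUN = 4  # approximate run time reported above


def build_command(script, i, j, project, data_root, extra_flags=()):
    """Assemble the argument list for one (L_i, U_j) training run."""
    return [
        "python", script,
        "--project", project,
        # Assumed file naming pattern, extrapolated from the README examples.
        "--data_path", f"{data_root}/labels/train_labeled_5_split_{i}.json",
        "--unlabelled_json_path", f"{data_root}/unlabelled_split_{j}.json",
        "--num_epochs", "200",
        *extra_flags,
    ]


pairs = list(product(range(4), range(4)))  # 4 labeled x 4 unlabeled subsets
commands = [
    build_command(
        "train_hco.py", i, j,
        project="/path/to/output/cadam_one_step_outputs_final/",
        data_root="/path/to/data",
        extra_flags=("--unlabeled_momentum_update",),
    )
    for i, j in pairs
]

print(len(commands), "runs,", len(commands) * HOURS_PER_RUN,
      "GPU-hours for one configuration")
print(" ".join(commands[0]))
```

Each argument list can be passed to `subprocess.run` or written to a job script; building lists rather than shell strings avoids quoting problems in paths.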

## Code Structure

- `train_hco.py`: Implementation of the HCO method for training
- `train_bl.py`: Implementation of baseline methods (AdamW and CO)
- `cadam.py`: Implementation of the AdamWHCO optimizer
- `run_inference.py`: Script for inference on test data
- `calc_metrics_asd.py`: Script for calculating evaluation metrics
- `models/`: Contains model architectures (V-Net backbone)
- `data/`: Contains dataset loaders and preprocessing functions
- `utils/`: Contains utility functions for training and evaluation

## Usage

### Training

#### AdamW Baseline

```bash
python train_bl.py \
    --project /path/to/output/cadam_one_step_outputs_final_baselines/ \
    --data_path /path/to/data/labels/train_labeled_5_split_0.json \
    --unlabelled_json_path /path/to/data/unlabelled_split_2.json \
    --num_epochs 200 
```

#### CO (Cautious Optimization) Baseline

```bash
python train_bl.py \
    --project /path/to/output/cadam_one_step_outputs_final_baselines/ \
    --data_path /path/to/data/labels/train_labeled_5_split_0.json \
    --unlabelled_json_path /path/to/data/unlabelled_split_2.json \
    --num_epochs 200 \
    --caution 
```

#### HCO (Hierarchical Cautious Optimization)

```bash
python train_hco.py \
    --project /path/to/output/cadam_one_step_outputs_final/ \
    --data_path /path/to/data/labels/train_labeled_5_split_0.json \
    --unlabelled_json_path /path/to/data/unlabelled_split_3.json \
    --num_epochs 200 \
    --unlabeled_momentum_update 
```

#### Ablation Study (without unlabeled momentum update)

```bash
python train_hco.py \
    --project /path/to/output/cadam_one_step_outputs_final/ \
    --data_path /path/to/data/labels/train_labeled_5_split_0.json \
    --unlabelled_json_path /path/to/data/unlabelled_split_3.json \
    --num_epochs 200
```

### Inference

To run inference on test data:

```bash
python run_inference.py \
    --model_weights /path/to/output/runs/EXP/weights/best.pth \
    --dest_path cadam_outputs/cadam_train_5_split_0_ul150_split_1_unlabeled
```

### Evaluation

To calculate the evaluation metrics (Dice, Jaccard, average surface distance (ASD), and 95th-percentile Hausdorff distance (HD95)):

```bash
python calc_metrics_asd.py \
    --preds "folder_path" \
    --output_csv "csv_output"
```
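
For readers unfamiliar with the overlap metrics, the sketch below shows the standard definitions of Dice and Jaccard on binary masks. It is a pure-Python illustration on flattened masks, not the code in `calc_metrics_asd.py`; the function name `dice_and_jaccard` and the convention of returning 1.0 when both masks are empty are assumptions.

```python
def dice_and_jaccard(pred, target):
    """Dice and Jaccard for two flat sequences of 0/1 voxel labels."""
    pred = [bool(p) for p in pred]
    target = [bool(t) for t in target]
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of voxels")

    intersection = sum(p and t for p, t in zip(pred, target))
    pred_size, target_size = sum(pred), sum(target)
    union = pred_size + target_size - intersection

    if union == 0:
        # Assumed convention: two empty masks agree perfectly.
        return 1.0, 1.0

    dice = 2.0 * intersection / (pred_size + target_size)
    jaccard = intersection / union
    return dice, jaccard


pred = [1, 1, 1, 0, 0, 0]
target = [0, 1, 1, 1, 0, 0]
dice, jaccard = dice_and_jaccard(pred, target)
print(f"Dice={dice:.3f} Jaccard={jaccard:.3f}")  # Dice=0.667 Jaccard=0.500
```

The values here lie in [0, 1]; the results in this README appear to be reported on a 0-100 scale.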


## Acknowledgments

This implementation builds on:
* [TorchSemiSeg](https://github.com/charlesCXK/TorchSemiSeg) 
* [UA-MT](https://github.com/yulequan/UA-MT)
* [SSL4MIS](https://github.com/HiLab-git/SSL4MIS)
* [UCMT](https://github.com/Senyh/UCMT)

