# SegVol: Universal and Interactive Volumetric Medical Image Segmentation
SegVol is a universal and interactive model for volumetric medical image segmentation. It accepts **point**, **box**, and **text** prompts and outputs a volumetric segmentation. Trained on 90k unlabeled Computed Tomography (CT) volumes and 6k labeled CTs, this foundation model supports the segmentation of over 200 anatomical categories.

We have released SegVol's **inference code**, **training code**, **model params** and **ViT pre-training params** (pre-training was performed for 2,000 epochs on 96k CTs).

**Keywords**: 3D medical SAM, volumetric image segmentation

## Start with source code
### Requirements
[PyTorch v1.11.0](https://pytorch.org/get-started/previous-versions/) or a higher version is required. Once it is installed, install the key requirements with the following commands:

```
pip install 'monai[all]==0.9.0'
pip install einops==0.6.1
pip install transformers==4.18.0
pip install matplotlib
```
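To confirm that the environment matches these pins before running anything, a small check like the one below can help. It is a minimal sketch using only the standard library; the `PINS` dictionary simply mirrors the commands above, and `check_pins` is a hypothetical helper name, not part of SegVol.

```python
from importlib import metadata

# Versions pinned in the install commands above.
PINS = {"monai": "0.9.0", "einops": "0.6.1", "transformers": "4.18.0"}


def check_pins(pins, get_version=metadata.version):
    """Return {package: (expected, found)} for each missing or mismatched package."""
    problems = {}
    for name, expected in pins.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            found = None  # the package is not installed
        if found != expected:
            problems[name] = (expected, found)
    return problems


if __name__ == "__main__":
    for name, (expected, found) in check_pins(PINS).items():
        print(f"{name}: expected {expected}, found {found}")
```

An empty result means every pinned package is installed at the expected version.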

### Guideline for training and inference

#### How to infer a demo case

1. Download the demo dataset [AbdomenCT-1K](https://github.com/JunMa11/AbdomenCT-1K) and choose any case you like.
2. Set the CT path and the ground-truth path of the case in *config_demo.json*.
3. Configure *inference_demo.sh* for execution:

    - `$segvol_ckpt`: the path to SegVol's checkpoint.

    - `$work_dir`: any folder in which to save the log files and visualization results.
4. Finally, you can control the **prompt type**, the **zoom-in-zoom-out mechanism**, and the **visualization switch** in *inference_demo.py*.
5. Run `bash script/inference_demo.sh` to infer your demo case.

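Box and point prompts are spatial: a box is a pair of corner coordinates, and a point is a single voxel coordinate. One common way to obtain them for a demo case is to derive them from the ground-truth mask. The sketch below shows that idea in plain Python on a tiny nested-list volume. It is an illustration under that assumption, not the code in *inference_demo.py*, and `foreground_voxels`, `bounding_box`, and `center_point` are hypothetical helper names.

```python
def foreground_voxels(mask):
    """List the (z, y, x) coordinates of all non-zero voxels in a nested-list 3D mask."""
    return [
        (z, y, x)
        for z, plane in enumerate(mask)
        for y, row in enumerate(plane)
        for x, value in enumerate(row)
        if value
    ]


def bounding_box(mask):
    """Return ((z_min, y_min, x_min), (z_max, y_max, x_max)) around the foreground."""
    coords = foreground_voxels(mask)
    if not coords:
        raise ValueError("mask has no foreground voxels")
    mins = tuple(min(c[i] for c in coords) for i in range(3))
    maxs = tuple(max(c[i] for c in coords) for i in range(3))
    return mins, maxs


def center_point(mask):
    """Return the foreground voxel closest to the foreground centroid."""
    coords = foreground_voxels(mask)
    if not coords:
        raise ValueError("mask has no foreground voxels")
    centroid = [sum(c[i] for c in coords) / len(coords) for i in range(3)]
    return min(coords, key=lambda c: sum((c[i] - centroid[i]) ** 2 for i in range(3)))


# Toy 2x3x4 volume with a small foreground blob.
mask = [
    [[0, 0, 0, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
    [[0, 0, 0, 0], [0, 1, 1, 1], [0, 0, 1, 0]],
]
box = bounding_box(mask)
point = center_point(mask)
print("box prompt:", box)
print("point prompt:", point)
```

Picking the foreground voxel nearest the centroid, rather than the centroid itself, keeps the point inside the structure even for concave shapes. With real volumes the same logic is usually written with array operations instead of nested lists.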
#### How to train SegVol

##### Build universal datasets

1. We use [Abdomenct-12organ](https://zenodo.org/records/7860267) as the demo dataset.
2. After downloading the demo dataset, configure *script/build_dataset.sh* to set the environment variables:
    * `$SAVE_ROOT` is the save path for the post-processed datasets.
    * `$DATASET_CODE` is your custom ID for the dataset. We suggest using `0000`, `0001`, ... as dataset IDs.
    * `$IMAGE_DIR` and `$LABEL_DIR` are the image directory path and label directory path of the original demo dataset.
    * `$TEST_RATIO` is the fraction of the whole set that is held out as val/test data.
3. **Set the `category` in *data_process/train_data_process.py*.** Categories must be listed in the same order as their corresponding indices in the ground-truth volume, and the `background` category should be left out.
4. Run `bash script/build_dataset.sh`.

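The ordering rule for `category` is easy to get wrong, so here is a minimal sketch of it. It assumes the ground-truth labels are consecutive integers starting at 1, so that label value `i + 1` corresponds to `category[i]` and label `0` (background) has no entry. The category names are placeholders for illustration, not the actual label set of the demo dataset, and `label_to_category` is a hypothetical helper.

```python
# Placeholder names; replace them with the categories of your dataset, in label order.
category = ["liver", "kidney", "spleen", "pancreas"]


def label_to_category(label_value, categories):
    """Map a ground-truth label value to its category name; 0 (background) maps to None."""
    if label_value == 0:
        return None
    if not 1 <= label_value <= len(categories):
        raise ValueError(f"label {label_value} has no category")
    return categories[label_value - 1]


for label_value in range(len(category) + 1):
    print(label_value, "->", label_to_category(label_value, category))
```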
If you want to combine **multiple datasets**, run *script/build_dataset.sh* once per dataset and assign a different `$DATASET_CODE` to each one.

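When several datasets are combined, each one needs its own code. The sketch below generates the suggested zero-padded numbering; the dataset names are placeholders and `assign_codes` is a hypothetical helper, not part of the repository.

```python
def assign_codes(dataset_names, width=4):
    """Assign zero-padded codes ('0000', '0001', ...) to dataset names, in order."""
    if len(set(dataset_names)) != len(dataset_names):
        raise ValueError("dataset names must be unique")
    return {name: f"{index:0{width}d}" for index, name in enumerate(dataset_names)}


codes = assign_codes(["dataset_a", "dataset_b", "dataset_c"])
print(codes)
```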
##### Build pseudo mask labels

After the universal datasets have been built, build pseudo mask labels for each CT in the post-processed datasets.

1. First, configure *script/build_pseudo_mask.sh*:
    * `$DATASET_ROOT` is the directory path for the post-processed datasets.
    * `$DATASET_CODE` is the custom code of your post-processed dataset.
2. Run `bash script/build_pseudo_mask.sh`. The pseudo masks for the `$DATASET_CODE` dataset will be generated at `$DATASET_ROOT/$DATASET_CODE/fh_seg`.

If you combine **multiple datasets**, run *script/build_pseudo_mask.sh* once for each dataset.

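Before moving on to training, it can be useful to confirm that every dataset has its pseudo masks. The sketch below relies only on the output location stated above (`$DATASET_ROOT/$DATASET_CODE/fh_seg`). The helper name `missing_pseudo_masks` is hypothetical, and treating an absent or empty `fh_seg` folder as "missing" is an assumption, since the file layout inside that folder is not described here.

```python
from pathlib import Path


def missing_pseudo_masks(dataset_root, dataset_codes):
    """Return the dataset codes whose fh_seg directory is absent or empty."""
    missing = []
    for code in dataset_codes:
        mask_dir = Path(dataset_root) / code / "fh_seg"
        if not mask_dir.is_dir() or not any(mask_dir.iterdir()):
            missing.append(code)
    return missing
```

For example, `missing_pseudo_masks("/path/to/datasets", ["0000", "0001"])` returns the codes that still need a run of *script/build_pseudo_mask.sh*.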
##### Training

1. Make sure you have completed the above steps correctly.
2. Set the environment variables in *script/train.sh*:
    * `$SEGVOL_CKPT` is the weight file of SegVol.
    * `$WORK_DIR` is the save path for log files and checkpoint files during training.
    * `$DATA_DIR` is the directory path of the post-processed datasets built above.
    * Define *dataset_codes* to indicate which datasets are used for training.
    * Configure the hyperparameters according to your training needs.
    * Set `$CUDA_VISIBLE_DEVICES` according to your devices.
3. Run `bash script/train.sh`.

##### Training from scratch

If you want to train from scratch without our SegVol checkpoint, we highly recommend that you use our pre-trained ViT and load the CLIP TextEncoder parameters.

#### How to use our pre-trained ViT as your model encoder

We pre-trained the ViT on 96k CTs for 2,000 epochs. The pre-trained ViT generalizes well and accelerates convergence.
A simple experiment on [AMOS22](https://amos22.grand-challenge.org/), training [UNETR](https://arxiv.org/abs/2103.10504) with and without the pre-trained encoder, illustrates this:

| Model |    Encoder    | Dice score (%) |
| :---: | :-----------: | :------------: |
| UNETR | w/o pre-train |     67.12      |
| UNETR | w/ pre-train  |     79.10      |

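The Dice score in the table measures the overlap between a predicted mask A and a ground-truth mask B as `2 * |A ∩ B| / (|A| + |B|)`. Below is a minimal sketch on flattened binary masks, in plain Python for illustration. The evaluation code behind the table is not shown here and may differ, for example in how empty masks are handled; `dice_score` is a hypothetical helper.

```python
def dice_score(pred, target):
    """Dice coefficient for two equal-length sequences of 0/1 values."""
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of voxels")
    intersection = sum(1 for p, t in zip(pred, target) if p and t)
    total = sum(1 for p in pred if p) + sum(1 for t in target if t)
    if total == 0:
        return 1.0  # convention chosen here: two empty masks agree perfectly
    return 2.0 * intersection / total


pred = [0, 1, 1, 1, 0, 0]
target = [0, 1, 1, 0, 1, 0]
print(f"Dice: {dice_score(pred, target):.4f}")
```

In this toy case two voxels overlap and each mask has three foreground voxels, so the score is 4/6. The table reports the same quantity as a percentage.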

You can use the ViT independently as your model's encoder. The demo code is as follows:

```python
import torch
from monai.networks.nets import ViT

vit_checkpoint = 'path/to/ViT_pretrain.ckpt'

# Build the ViT that will receive the pre-trained weights.
vit = ViT(
    in_channels=1,
    img_size=(32, 256, 256),
    patch_size=(4, 16, 16),
    pos_embed="perceptron",
)
print(vit)

# Load the checkpoint on CPU and keep only the encoder weights,
# stripping the 'model.encoder.' prefix so the keys match the standalone ViT.
with open(vit_checkpoint, "rb") as f:
    state_dict = torch.load(f, map_location='cpu')['state_dict']
    encoder_dict = {k.replace('model.encoder.', ''): v for k, v in state_dict.items() if 'model.encoder.' in k}
vit.load_state_dict(encoder_dict)
print(f'Image_encoder load param: {vit_checkpoint}')
```
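The key step in the snippet above is renaming checkpoint keys so that they match the standalone ViT. That step can be tried in isolation on a plain dictionary, without loading any weights. The sketch below uses `str.startswith` instead of a substring test, so only keys that begin with the prefix are kept. The example keys are made up for illustration, and `strip_prefix` is a hypothetical helper.

```python
def strip_prefix(state_dict, prefix):
    """Keep only entries whose key starts with `prefix`, with the prefix removed."""
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Made-up keys with placeholder values standing in for tensors.
checkpoint = {
    "model.encoder.patch_embedding.cls_token": 0,
    "model.encoder.norm.weight": 1,
    "model.decoder.out.weight": 2,
}
encoder_only = strip_prefix(checkpoint, "model.encoder.")
print(sorted(encoder_only))
```

If `load_state_dict` later reports missing or unexpected keys, printing the remapped key names this way is a quick method to see where the checkpoint and the model disagree.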

### Datasets involved

Links to the original datasets:

| Dataset | Link |
| ------------- | ------------- |
| 3D-IRCADB | https://www.kaggle.com/datasets/nguyenhoainam27/3dircadb |
| AbdomenCT-1K | https://github.com/JunMa11/AbdomenCT-1K |
| AMOS22 | https://amos22.grand-challenge.org/ |
| BTCV | https://www.synapse.org/#!Synapse:syn3193805/wiki/217752 |
| CHAOS | https://chaos.grand-challenge.org/ |
| CT-ORG | https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=61080890 |
| FLARE22 | https://flare22.grand-challenge.org/ |
| HaN-Seg | https://han-seg2023.grand-challenge.org/ |
| KiPA22 | https://kipa22.grand-challenge.org/ |
| KiTS19 | https://kits19.grand-challenge.org/ |
| KiTS23 | https://kits-challenge.org/kits23/ |
| LUNA16 | https://luna16.grand-challenge.org/Data/ |
| MSD-Colon | http://medicaldecathlon.com/ |
| MSD-HepaticVessel | http://medicaldecathlon.com/ |
| MSD-Liver | http://medicaldecathlon.com/ |
| MSD-Lung | http://medicaldecathlon.com/ |
| MSD-Pancreas | http://medicaldecathlon.com/ |
| MSD-Spleen | http://medicaldecathlon.com/ |
| Pancreas-CT | https://wiki.cancerimagingarchive.net/display/public/pancreas-ct |
| QUBIQ | https://qubiq.grand-challenge.org/ |
| SLIVER07 | https://sliver07.grand-challenge.org/ |
| TotalSegmentator | https://github.com/wasserth/TotalSegmentator |
| ULS23 | https://uls23.grand-challenge.org/ |
| VerSe19 | https://osf.io/nqjyw/ |
| VerSe20 | https://osf.io/t98fz/ |
| WORD | https://paperswithcode.com/dataset/word |

Thanks to the following amazing works:

[HuggingFace](https://huggingface.co/).

[CLIP](https://github.com/openai/CLIP).

[MONAI](https://github.com/Project-MONAI/MONAI).

[3D Slicer](https://www.slicer.org/).

[Image by brgfx](https://www.freepik.com/free-vector/anatomical-structure-human-bodies_26353260.htm) on Freepik.

[Image by muammark](https://www.freepik.com/free-vector/people-icon-collection_1157380.htm#query=user&position=2&from_view=search&track=sph) on Freepik.



