<div align="center">
    <h2>
        Diffusion Domain Teacher: Diffusion Guided Domain Adaptive Object Detector
    </h2>
</div>
<br>

## Introduction

This repository contains the code implementation of the paper **Diffusion Domain Teacher: Diffusion Guided Domain Adaptive Object Detector**. It is built on the [MMDetection](https://github.com/open-mmlab/mmdetection) project.

  Object detectors often suffer a drop in performance due to the large domain gap between the training data (source domain) and real-world data (target domain).
  Diffusion-based generative models have shown a remarkable ability to generate high-quality and diverse images, suggesting their potential for extracting valuable features from various domains.
  To effectively leverage the cross-domain feature representations of diffusion models, we train a detector with a frozen-weight diffusion model on the source domain,
  then employ it as a teacher model to generate pseudo labels on the unlabeled target domain; these pseudo labels guide the supervised learning of the student model on the target domain.
  We refer to this approach as Diffusion Domain Teacher (DDT). With this straightforward yet powerful framework,
  we significantly improve cross-domain object detection performance without compromising inference speed.
  Our method achieves an average mAP improvement of 21.2% over the baseline on 6 datasets from three common cross-domain detection benchmarks (Cross-Camera, Syn2Real, Real2Artistic),
  surpassing the current state-of-the-art (SOTA) methods by an average of 5.7% mAP. Furthermore, extensive experiments demonstrate
  that our method consistently brings improvements even with more powerful and complex models, such as large backbones, self-supervised models, and large pre-trained models, highlighting the broad applicability and effectiveness of DDT for domain adaptation.

## Installation
### Requirements
- Linux (Windows has not been tested)
- Python 3.8+, recommended 3.11
- PyTorch 2.0 or higher, recommended 2.0.0
- CUDA 11.7 or higher, recommended 11.8
- MMCV 2.0 or higher, recommended 2.0.0
- MMDetection 3.0 or higher, recommended 3.3.0
- diffusers 0.20.0 or higher, recommended 0.20.2
### Environment Installation

We recommend using conda for installation. The following commands create a virtual environment named `DDT` and install PyTorch, MMCV, MMDetection, and diffusers. The steps below assume CUDA **11.8**.
If your CUDA version is not 11.8, adjust the commands accordingly.
Note: If you are experienced with PyTorch and already have it installed, you can skip Step 2. Otherwise, follow all the steps below.

**Step 1**: Create a virtual environment named `DDT` and activate it.

```shell
conda create -n DDT python=3.11 -y
conda activate DDT
```

**Step 2**: Install [PyTorch 2.x](https://pytorch.org/get-started/locally/).

Linux/Windows:
```shell
conda install pytorch==2.0.0 torchvision==0.15.0 torchaudio==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia
```

**Step 3**: Install [MMDetection-3.x](https://mmdetection.readthedocs.io/en/latest/get_started.html).

```shell
pip install -U openmim
mim install mmengine
mim install "mmcv>=2.0.0,<2.2.0"  # MMDetection 3.3.0 does not support MMCV 2.2.0 or later
mim install "mmdet==3.3.0"
```

**Step 4**: Download [Stable Diffusion v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) in diffusers format.

```shell
git lfs install
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
# Alternative: clone without downloading the large weight files (only their LFS pointers)
# GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
```
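If you used the `GIT_LFS_SKIP_SMUDGE=1` variant, or the LFS download was interrupted, the weight files are small text pointers rather than real weights, and loading the model will fail later. The sketch below is our own illustration (not part of this repository): it assumes the usual diffusers layout (a top-level `model_index.json` and UNet weights under `unet/`), and the helper names `is_lfs_pointer` and `check_sd_dir` are hypothetical.

```shell
#!/usr/bin/env bash
# Sketch: sanity-check a local diffusers-format Stable Diffusion checkout.
# Assumes the standard diffusers layout: model_index.json plus unet/, vae/, ...

# Return 0 if FILE is a Git LFS pointer stub instead of the real content.
is_lfs_pointer() {
  head -c 100 "$1" 2>/dev/null | grep -q '^version https://git-lfs.github.com/spec/v1'
}

# Check DIR for the diffusers index file and for real (non-pointer) UNet weights.
check_sd_dir() {
  local dir="$1" found=0 weights
  if [ ! -f "$dir/model_index.json" ]; then
    echo "missing $dir/model_index.json" >&2
    return 1
  fi
  for weights in "$dir"/unet/diffusion_pytorch_model*; do
    [ -f "$weights" ] || continue   # glob did not match anything
    found=1
    if is_lfs_pointer "$weights"; then
      echo "$weights is an LFS pointer; run 'git lfs pull' inside $dir" >&2
      return 1
    fi
  done
  if [ "$found" -eq 0 ]; then
    echo "no UNet weights found under $dir/unet" >&2
    return 1
  fi
  echo "ok: $dir"
}
```

Usage: source the snippet, then run `check_sd_dir stable-diffusion-v1-5`. If it reports pointer files, `git lfs pull` inside the checkout downloads the actual weights.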
Move the downloaded **Stable-diffusion-1.5** folder into the same directory as our **DDT** repository, then install diffusers:

```shell
pip install diffusers==0.20.2
```
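Once all packages are installed, a quick way to confirm the environment is to print the version of each package from the Requirements list. The sketch below is our illustration, not a script shipped with the repository; it assumes `python3` resolves to the interpreter of the activated `DDT` environment, and the function name `check_env` is hypothetical. Failed imports are reported rather than aborting the check (for example, MMDetection refuses to import when the installed MMCV version is outside its supported range).

```shell
#!/usr/bin/env bash
# Sketch: print the versions of the packages listed under Requirements.
# Run inside the activated DDT environment; missing or broken packages
# are reported instead of stopping the check.
check_env() {
  python3 - <<'EOF'
import importlib
import sys

print(f"python {sys.version.split()[0]}")
for name in ("torch", "torchvision", "mmengine", "mmcv", "mmdet", "diffusers"):
    try:
        module = importlib.import_module(name)
        print(f"{name} {getattr(module, '__version__', 'unknown')}")
    except Exception as exc:  # ImportError, or a version check failing at import
        print(f"{name} NOT AVAILABLE ({type(exc).__name__})")

try:
    import torch
    print(f"cuda available: {torch.cuda.is_available()}, built for CUDA {torch.version.cuda}")
except Exception:
    pass
EOF
}

check_env
```

Compare the printed versions with the recommended ones above; `built for CUDA None` means a CPU-only PyTorch build was installed.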

## Dataset Preparation

### Cross-domain detection datasets

The cross-domain detection datasets used in the paper are listed below. To make the code quick to set up, we provide COCO-style JSON annotation files for all datasets on [Google Drive](https://drive.google.com/drive/folders/1D1UMj79BdUBW1Jppl0R9OxUFKrgi3Eei?usp=drive_link).

- Image and annotation download link: [Cityscapes](https://www.cityscapes-dataset.com).
- Image and annotation download link: [BDD 100k](https://bdd-data.berkeley.edu/).
- Image and annotation download link: [SIM10k](https://fcav.engin.umich.edu/projects/driving-in-the-matrix).
- Image and annotation download link: [VOC 07+12](http://host.robots.ox.ac.uk/pascal/VOC/).
- Image and annotation download link: [Clipart, Comic, Watercolor](https://github.com/naoto0804/cross-domain-detection/tree/master/datasets).
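After downloading, it can help to inspect each COCO-style JSON file before training, for instance to confirm that the source and target annotation files of a setting share the same category list. The helper below is a hypothetical sketch (not part of the repository) that only relies on the standard COCO keys `images`, `annotations`, and `categories`.

```shell
#!/usr/bin/env bash
# Sketch: summarize a COCO-style annotation file.
# Prints image/annotation/category counts and the category names.
coco_summary() {
  python3 - "$1" <<'EOF'
import json
import sys

with open(sys.argv[1]) as f:
    coco = json.load(f)
names = [c["name"] for c in coco.get("categories", [])]
print(f"images={len(coco.get('images', []))} "
      f"annotations={len(coco.get('annotations', []))} "
      f"categories={len(names)}")
print("names=" + ",".join(names))
EOF
}
```

Usage: `coco_summary path/to/annotations.json` (the path is a placeholder for wherever you saved the files).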

## Code of our DDT

**Important code directories**:

- `DA`: Root directory of the DDT config files.
- `DA/_base_/da_setting`: Training schedule and learning-rate configs of DDT.
- `DA/_base_/datasets`: Dataset configs for the six cross-domain settings.
- `DA/Ours`: Detector configs of DDT for the six cross-domain settings.
- [`mmdet/models/backbones/dift_encoder.py`](mmdet/models/backbones/dift_encoder.py): Implementation and settings of the diffusion backbone.
- [`mmdet/models/detectors/Z_dift_semi_base.py`](mmdet/models/detectors/Z_dift_semi_base.py): Diffusion teacher for self-training.
- [`mmdet/models/detectors/Z_domain_adaptation_detector.py`](mmdet/models/detectors/Z_domain_adaptation_detector.py): Main implementation of DDT.
- [`mmdet/engine/hooks/adaptive_teacher_hook.py`](mmdet/engine/hooks/adaptive_teacher_hook.py): Training and testing hooks of DDT.

## Model Training

### Diffusion detector training
The models are trained for 20,000 steps on two NVIDIA RTX 3090 GPUs with a batch size of 16.
If your setup differs from ours, modify the number of training steps and the default learning rate in the [training config](DA/_base_/da_setting/da_20k_0.1backbone.py).
Alternatively, you can download the trained models we provide on [Google Drive](https://drive.google.com/drive/folders/11YnyZLGXqJmgwptBzZRoRO3lvdHBudPw?usp=drive_link).
#### Multi-GPU Training
```shell
bash ./tools/dist_train.sh ${CONFIG_FILE} ${GPU_NUM}  # CONFIG_FILE is the config file to use, GPU_NUM is the number of GPUs
```
For example:
```shell
bash ./tools/dist_train.sh DA/Ours/city_to_bdd100k/diffusion_faster-rcnn_fpn_city_to_bdd100k_source.py 2
```
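When your total batch size differs from 16 (assuming 16 is the total across both GPUs), one common heuristic is the linear scaling rule: scale the learning rate in proportion to the batch size and scale the number of steps inversely, so the model sees the same number of images. This is our assumption for illustration, not necessarily how the DDT schedules were tuned; the function `scale_schedule` and the base learning rate in the example are hypothetical, so read the actual value from the training config.

```shell
#!/usr/bin/env bash
# Sketch: rescale lr and steps for a different total batch size using the
# linear scaling rule (lr proportional to batch size, same number of images seen).
# Usage: scale_schedule BASE_LR BASE_BATCH NEW_BATCH BASE_STEPS
scale_schedule() {
  awk -v lr="$1" -v b0="$2" -v b1="$3" -v s0="$4" 'BEGIN {
    printf "lr=%g steps=%d\n", lr * b1 / b0, int(s0 * b0 / b1 + 0.5)
  }'
}

# Example with a hypothetical base lr of 0.02 at batch 16, moving to batch 8.
scale_schedule 0.02 16 8 20000   # prints: lr=0.01 steps=40000
```

If this repository's `tools/train.py` follows upstream MMDetection 3.x, its `--auto-scale-lr` flag can rescale the learning rate automatically when the config defines `auto_scale_lr`; the step count still has to be changed by hand.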

### DDT training
The models are trained for 20,000 steps on two NVIDIA RTX 3090 GPUs with a batch size of 16.
We train exclusively on the source domain for the first 12,000 steps, then train jointly on the source and target domains for the remaining 8,000 steps.

If your setup differs from ours, modify the number of training steps and the default learning rate in the [training config](DA/_base_/da_setting/semi_e2e_20k_0.1backbone.py).
Before training, update the config and checkpoint paths of the diffusion detector [here](DA/Ours/city_to_bdd100k/DDT_r101_fpn_city_to_bdd100k.py) so that they point to the diffusion detector you trained (or downloaded) in the previous step.

For example:
```shell
bash ./tools/dist_train.sh DA/Ours/city_to_bdd100k/DDT_r101_fpn_city_to_bdd100k.py 2
```
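The two-stage DDT schedule described above splits 20,000 steps into 12,000 source-only steps followed by 8,000 joint source and target steps. If you change the total number of steps, one simple option (our assumption, not a recommendation from the paper) is to keep the same 60/40 split. The hypothetical helper `ddt_phase_split` below only does that arithmetic; check [`semi_e2e_20k_0.1backbone.py`](DA/_base_/da_setting/semi_e2e_20k_0.1backbone.py) for the field that actually sets the stage boundary.

```shell
#!/usr/bin/env bash
# Sketch: split a total step budget into the source-only stage and the joint
# stage, keeping the 12,000 / 20,000 ratio of the default DDT schedule.
# Usage: ddt_phase_split TOTAL_STEPS
ddt_phase_split() {
  local total="$1"
  local source_only=$(( total * 12000 / 20000 ))
  echo "source_only=$source_only joint=$(( total - source_only ))"
}

ddt_phase_split 20000   # default schedule: source_only=12000 joint=8000
```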

## Model Testing
### Multi-GPU Testing

Note: Before testing, set `detector.dift_model.config` and `detector.dift_model.pretrained_model` to `None` in the DDT config file (for example, [here](DA/Ours/sim10k_to_bdd100k/DDT_r101_fpn_sim10k_to_bdd100k.py)).

```shell
bash ./tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM}  # CONFIG_FILE is the config file to use, CHECKPOINT_FILE is the checkpoint to evaluate, GPU_NUM is the number of GPUs
```

## Trained models

- **Cityscapes to BDD100k:** [Google Drive](https://drive.google.com/drive/folders/11YnyZLGXqJmgwptBzZRoRO3lvdHBudPw?usp=drive_link)
- **Sim10k to BDD100k:** [Google Drive](https://drive.google.com/drive/folders/11YnyZLGXqJmgwptBzZRoRO3lvdHBudPw?usp=drive_link)
- **Sim10k to Cityscapes:** [Google Drive](https://drive.google.com/drive/folders/11YnyZLGXqJmgwptBzZRoRO3lvdHBudPw?usp=drive_link)
- **VOC to Clipart:** [Google Drive](https://drive.google.com/drive/folders/11YnyZLGXqJmgwptBzZRoRO3lvdHBudPw?usp=drive_link)
- **VOC to Comic:** [Google Drive](https://drive.google.com/drive/folders/11YnyZLGXqJmgwptBzZRoRO3lvdHBudPw?usp=drive_link)
- **VOC to Watercolor:** [Google Drive](https://drive.google.com/drive/folders/11YnyZLGXqJmgwptBzZRoRO3lvdHBudPw?usp=drive_link)

