
# MMaDA 

## 🔧 Training
**Update your training data path in `configs/xx.yaml`.**

### Stage 0. Prepare your accelerate configs
First, prepare your accelerate config. You can simply run:
```
accelerate config
```

Or use our provided configs in `accelerate_configs`:
```
├── accelerate_configs/ 
|   ├── 1_gpu.yaml
|   └── 8_node_8_gpus_deepspeed_zero2.yaml (for 8 * 8 gpus)
```
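
As a minimal sketch of how one of the provided configs plugs into the launch commands used in the stages below, the snippet stores the config path in a variable and assembles the Stage 1.1 command. The `ACCELERATE_CONFIG` variable and the `launch_cmd` helper are illustrative names, not part of the repository. The helper only prints the command so you can inspect it before running it.

```bash
#!/usr/bin/env bash
# Illustrative sketch: ACCELERATE_CONFIG and launch_cmd are assumed names,
# not something the repository provides.
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-accelerate_configs/1_gpu.yaml}"

# Print (do not execute) the accelerate command for a training script and its config.
launch_cmd() {
  local script="$1" train_config="$2"
  printf 'accelerate launch --config_file %s --main_process_port=8888 %s config=%s\n' \
    "$ACCELERATE_CONFIG" "$script" "$train_config"
}

# Stage 1.1 command, using the single-GPU config by default.
launch_cmd training/train_mmada.py configs/mmada_pretraining_stage1_llada_instruct.yaml
```

Once the printed command looks right, pipe the output to `bash` to run it, or export `ACCELERATE_CONFIG` to point at a different config file first.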

### Stage 1.1: Pre-training on ImageNet
First, we initialize our model from LLaDA-8B-Instruct and train it on ImageNet to acquire basic visual capabilities.
```
accelerate launch --config_file path/to/your/accelerate_config --main_process_port=8888 training/train_mmada.py config=configs/mmada_pretraining_stage1_llada_instruct.yaml
```

### Stage 1.2 Pre-training on Image-Text Dataset
Next, we replace the ImageNet dataset from Stage 1.1 with an image-text dataset. Set the pretrained model path in `mmada_pretraining_stage2_llada_instruct.yaml` to your Stage 1.1 checkpoint.
```
accelerate launch --config_file path/to/your/accelerate_config --main_process_port=8888 training/train_mmada_stage2.py config=configs/mmada_pretraining_stage2_llada_instruct.yaml
```

### Stage 1.3 Pre-training on Text Instruction Following
In this stage, we begin training on text instruction following and include the corresponding validations. Set the pretrained model path in `mmada_pretraining_stage3_llada_instruct.yaml` to your Stage 1.2 checkpoint.
```
accelerate launch --config_file path/to/your/accelerate_config --main_process_port=8888 training/train_mmada_stage3.py config=configs/mmada_pretraining_stage3_llada_instruct.yaml
```

### Stage 2.1 Mix-CoT Training (Text Only)
In this stage, we begin Mix-CoT fine-tuning with text reasoning first, while also improving image quality. Set the pretrained model path in `mmada_pretraining_stage3_llada_instruct_512_cot.yaml` (the config passed to the command below) to your Stage 1.3 checkpoint, and prepare your CoT data.
```
accelerate launch --config_file path/to/your/accelerate_config --main_process_port=8888 training/train_mmada_stage_cot_sft.py config=configs/mmada_pretraining_stage3_llada_instruct_512_cot.yaml
```

### Stage 2.2 Mix-CoT Training (with MultiModal Reasoning)
In this stage, we add multimodal reasoning, while continuing to improve image quality. Set the pretrained model path in `mmada_pretraining_stage4_llada_instruct.yaml` (the config passed to the command below) to your Stage 2.1 checkpoint, and prepare your CoT data.
```
accelerate launch --config_file path/to/your/accelerate_config --main_process_port=8888 training/train_mmada_stage4.py config=configs/mmada_pretraining_stage4_llada_instruct.yaml
```
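
Because each stage starts from the previous stage's checkpoint, it can help to confirm that every script and config is in place before launching anything. The sketch below lists the script/config pairs from the commands above and reports any file that is missing. The `STAGES` array and the `check_stages` helper are illustrative names, not part of the repository, and the check does not verify that the pretrained model paths inside the configs have been updated.

```bash
#!/usr/bin/env bash
# Illustrative sketch: STAGES and check_stages are assumed names. The
# script/config pairs are copied from the launch commands in this README.
STAGES=(
  "1.1 training/train_mmada.py configs/mmada_pretraining_stage1_llada_instruct.yaml"
  "1.2 training/train_mmada_stage2.py configs/mmada_pretraining_stage2_llada_instruct.yaml"
  "1.3 training/train_mmada_stage3.py configs/mmada_pretraining_stage3_llada_instruct.yaml"
  "2.1 training/train_mmada_stage_cot_sft.py configs/mmada_pretraining_stage3_llada_instruct_512_cot.yaml"
  "2.2 training/train_mmada_stage4.py configs/mmada_pretraining_stage4_llada_instruct.yaml"
)

# Report every missing script or config under the given repository root.
# Returns 0 when all files exist and 1 otherwise.
check_stages() {
  local root="${1:-.}" missing=0 entry stage script config f
  for entry in "${STAGES[@]}"; do
    read -r stage script config <<< "$entry"
    for f in "$script" "$config"; do
      if [ ! -f "$root/$f" ]; then
        echo "Stage $stage: missing $f"
        missing=1
      fi
    done
  done
  return "$missing"
}

check_stages . || echo "Fix the paths above before launching."
```

Run it from the repository root, or pass the root as the first argument to `check_stages`.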

### Stage 3 UniGRPO RL
This part of the code will be released upon acceptance.



## 🚀 Inference
For batch-level inference, we provide our inference scripts here.
### 1. Text Generation
For text generation, we follow LLaDA's configuration and generation script. Simply run:
```bash
python generate.py
```

### 2. MultiModal Generation
For multimodal generation, first log in to your wandb account:
```
wandb login
```
Then run the multimodal generation demo. You can view the results on wandb:
```
python3 inference_mmu.py config=configs/mmada_demo.yaml mmu_image_root=./mmu_validation question='Please describe this image in detail.' 
```

### 3. Text-to-Image Generation
For text-to-image generation, first log in to your wandb account:
```
wandb login
```
Then run the text-to-image generation demo. You can view the results on wandb:
```
python3 inference_t2i.py config=configs/mmada_demo.yaml batch_size=1 validation_prompts_file=validation_prompts/text2image_prompts.txt guidance_scale=3.5 generation_timesteps=15 mode='t2i'
```
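
To compare settings, one option is to repeat the command above over a small grid of `guidance_scale` and `generation_timesteps` values. The sketch below only prints the commands. The `t2i_sweep` helper and the grid values are assumptions for illustration: only `guidance_scale=3.5` and `generation_timesteps=15` come from this README, and the other values are not recommendations.

```bash
#!/usr/bin/env bash
# Illustrative sketch: t2i_sweep and the grid values are assumptions.
# Only guidance_scale=3.5 and generation_timesteps=15 appear in this README.
GUIDANCE_SCALES=(1.5 3.5 5.0)
GENERATION_TIMESTEPS=(15 30)

# Print one text-to-image command per (guidance_scale, generation_timesteps) pair.
t2i_sweep() {
  local scale steps
  for scale in "${GUIDANCE_SCALES[@]}"; do
    for steps in "${GENERATION_TIMESTEPS[@]}"; do
      echo "python3 inference_t2i.py config=configs/mmada_demo.yaml batch_size=1" \
           "validation_prompts_file=validation_prompts/text2image_prompts.txt" \
           "guidance_scale=$scale generation_timesteps=$steps mode='t2i'"
    done
  done
}

t2i_sweep
```

Pipe the output to `bash` (`t2i_sweep | bash`) to run the commands one after another.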