# Mask What Matters: Controllable Text-Guided Masking for Self-Supervised Medical Image Analysis

This repository contains the implementation of MWM, a text-driven self-supervised learning framework that integrates region-aware masked image modeling with natural language guidance. The project builds upon SparK and supports pretraining, localization, and downstream transfer tasks such as classification, detection, and segmentation.

> ⚠️ Note: The code is still being cleaned up. A complete release will follow shortly.

---

## 1. Environment Setup

We recommend using `conda` to manage the Python environment:

```bash
conda create -n spark python=3.8 -y
conda activate spark
```

---

## 2. Dependencies

Install PyTorch, `timm`, and other dependencies:

```bash
pip install torch==1.10.0+cu113 torchvision==0.11.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html
pip install timm==0.5.4
pip install -r requirements.txt
```
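
To confirm that the environment matches the pinned versions, a small check along these lines can help. This is a minimal sketch and not part of the repository: `PINNED`, `installed_version`, and `check_versions` are hypothetical names, and the comparison ignores the `+cu113` build tag so that other builds of the same release also pass.

```python
from importlib import metadata

# Versions pinned in the install commands above. The "+cu113" local tag is
# stripped before comparing, so other builds of the same release also pass.
PINNED = {"torch": "1.10.0", "torchvision": "0.11.1", "timm": "0.5.4"}


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_versions(pinned=PINNED, lookup=installed_version):
    """Map each package name to a tuple (expected, found, ok)."""
    report = {}
    for name, expected in pinned.items():
        found = lookup(name)
        base = found.split("+")[0] if found else None
        report[name] = (expected, found, base == expected)
    return report


if __name__ == "__main__":
    for name, (expected, found, ok) in check_versions().items():
        status = "OK" if ok else "MISMATCH"
        print(f"{name:12s} expected {expected:8s} found {found or 'not installed':16s} {status}")
```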

> ✅ Please use the recommended versions to ensure consistent and reproducible results.

---

## 3. Dataset Preparation

### ImageNet-1k (for large-scale pretraining)

```
/path/to/imagenet/
├── train/
│   ├── class1/
│   │   └── *.jpeg
│   └── class2/
│       └── *.jpeg
└── val/
    ├── class1/
    │   └── *.jpeg
    └── class2/
        └── *.jpeg
```

Pass the dataset path via `--data_path`:

```bash
python main.py --data_path /path/to/imagenet ...
```
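
Before starting a long run, it may be worth checking that the folder follows the layout above. The sketch below is an illustration only: `check_imagenet_layout` and `summarize_split` are hypothetical helpers, and accepting `.jpg` and `.png` alongside `.jpeg` is an assumption.

```python
from pathlib import Path

# Assumption: only *.jpeg is shown in the layout above; extend or trim as needed.
IMAGE_SUFFIXES = {".jpeg", ".jpg", ".png"}


def summarize_split(split_dir):
    """Count image files per class folder under one split (train/ or val/)."""
    counts = {}
    for class_dir in sorted(Path(split_dir).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1 for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
            )
    return counts


def check_imagenet_layout(root):
    """Return a list of problems; an empty list means the layout looks fine."""
    root = Path(root)
    problems = []
    summaries = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing directory: {split_dir}")
            continue
        summaries[split] = summarize_split(split_dir)
        if not summaries[split]:
            problems.append(f"no class folders in {split_dir}")
        for name, n in summaries[split].items():
            if n == 0:
                problems.append(f"no images in {split_dir / name}")
    # Only compare class names when both splits exist.
    if len(summaries) == 2 and set(summaries["train"]) != set(summaries["val"]):
        problems.append("train/ and val/ contain different class folders")
    return problems
```

Calling `check_imagenet_layout("/path/to/imagenet")` and printing the returned list gives a quick report of anything that looks off.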

### Medical Pretraining (e.g., Chest CT)

```
/path/to/your_data/
├── images/
│   └── *.png or *.nii.gz
└── masks_expanded/
    └── *.png or *.npy  # semantic masks generated via text2region
```
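
The layout above does not state how an image is matched to its mask. The sketch below assumes they share a base name (for example `case1.nii.gz` and `case1.npy`), which is a guess rather than a documented convention; `base_name` and `pair_images_with_masks` are hypothetical helpers for listing images that have no mask yet.

```python
from pathlib import Path

IMAGE_EXTS = (".nii.gz", ".png")
MASK_EXTS = (".png", ".npy")


def base_name(path, extensions):
    """Strip one of the known extensions (handles the double suffix .nii.gz)."""
    name = Path(path).name
    for ext in extensions:
        if name.lower().endswith(ext):
            return name[: -len(ext)]
    return None  # not a recognized file type


def pair_images_with_masks(root):
    """Return (pairs, images_without_mask) for a dataset root.

    Assumption: an image and its mask share the same base name.
    """
    root = Path(root)
    masks = {}
    for m in (root / "masks_expanded").iterdir():
        key = base_name(m, MASK_EXTS)
        if key is not None:
            masks[key] = m
    pairs, missing = [], []
    for img in sorted((root / "images").iterdir()):
        key = base_name(img, IMAGE_EXTS)
        if key is None:
            continue  # skip files that are not images
        if key in masks:
            pairs.append((img, masks[key]))
        else:
            missing.append(img)
    return pairs, missing
```

Any path in the second return value is an image for which `text2region` output has not been found under this naming assumption.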

---

## 4. Project Structure

```
MWM/
├── pretrain/              # Pretraining logic and main training scripts
├── text2region/           # Text-guided region localization module
├── downstream_imagenet/   # Classification tasks on ImageNet-style datasets
├── downstream_d2/         # Object detection and segmentation via Detectron2
├── LICENSE
└── INSTALL.md             # This README
```

---

## 5. Module Descriptions

### 🔍 `text2region/` — Text-Guided Region Localization

This module performs semantic region localization based on natural language prompts. It leverages pretrained vision-language models to identify task-relevant areas for use in self-supervised training.
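
As a rough illustration of how a localization result could become a region mask, the sketch below thresholds a per-pixel relevance map into a binary mask. Everything here is an assumption: the module's actual method is not described in this README, and `binarize_relevance`, the toy scores, and the `threshold` value are hypothetical.

```python
def binarize_relevance(relevance, threshold=0.5):
    """Min-max normalize a 2D relevance map and threshold it into a 0/1 mask."""
    flat = [v for row in relevance for v in row]
    lo, hi = min(flat), max(flat)
    span = hi - lo
    if span == 0:
        # A constant map carries no localization signal, so select nothing.
        return [[0 for _ in row] for row in relevance]
    return [[1 if (v - lo) / span >= threshold else 0 for v in row] for row in relevance]


# Toy 3x4 relevance map standing in for text-image similarity scores.
relevance = [
    [0.1, 0.2, 0.1, 0.0],
    [0.2, 0.9, 0.8, 0.1],
    [0.1, 0.7, 0.6, 0.0],
]
mask = binarize_relevance(relevance, threshold=0.5)
for row in mask:
    print(row)
```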

---

### ⚙️ `pretrain/` — Region-Aware Masked Pretraining

This folder contains the core training code for region-aware masked image modeling. You can launch full pretraining with:

```bash
python run_targeted_masking.py
```

The script runs interactively and lets you configure:

- Random or text-guided masking
- Masking ratios for salient and non-salient regions
- Optional experiment name, output directory, and pretrained initialization
- Distributed training via `torchrun`, if desired

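
The idea of separate masking ratios for salient and non-salient regions can be sketched as follows. This is a minimal illustration and not the repository's implementation: `sample_region_mask` and its parameters are hypothetical, and sampling patches independently within each region is an assumption. Using the same ratio for both regions gives a stratified form of random masking.

```python
import random


def sample_region_mask(salient, salient_ratio, background_ratio, seed=None):
    """Pick patches to mask, with a separate ratio for each region.

    salient: flat list of 0/1 flags, one per patch (1 = inside the text-guided region).
    Both ratios must lie in [0, 1].
    Returns a flat list of booleans, True where the patch is masked.
    """
    rng = random.Random(seed)
    salient_idx = [i for i, s in enumerate(salient) if s]
    background_idx = [i for i, s in enumerate(salient) if not s]
    masked = [False] * len(salient)
    for indices, ratio in ((salient_idx, salient_ratio), (background_idx, background_ratio)):
        k = round(len(indices) * ratio)  # number of patches to mask in this region
        for i in rng.sample(indices, k):
            masked[i] = True
    return masked


# 4x4 patch grid (flattened) with a 2x2 salient block in the middle.
salient = [0, 0, 0, 0,
           0, 1, 1, 0,
           0, 1, 1, 0,
           0, 0, 0, 0]
# Mask 3 of the 4 salient patches and 6 of the 12 background patches.
masked = sample_region_mask(salient, salient_ratio=0.75, background_ratio=0.5, seed=0)
```
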
---

### 📦 `downstream_d2/` — Detection and Segmentation

Implements downstream evaluation using [Detectron2](https://github.com/facebookresearch/detectron2). Suitable for object detection or instance segmentation tasks on medical or natural datasets.

Example:
```bash
python downstream_d2/train_net.py \
  --config-file configs/ct_tumor.yaml \
  --opts MODEL.WEIGHTS /path/to/pretrained_model.pth
```

---

### 🧪 `downstream_imagenet/` — Image Classification Tasks

This module supports fine-tuning or linear probing on ImageNet-style datasets. Both start from a pretrained encoder: fine-tuning updates the encoder weights, while linear probing trains only a classifier on top of frozen features.

---
