# C-MAE: A PyTorch Implementation

**Supplemental Material for ICLR 2026 Submission**

This is a PyTorch implementation of the paper `Benchmarking Self-Supervised Vision Transformer in Astronomy`.

* The implementation mainly consists of two parts: `pre-training` and `fine-tuning`.


* We also provide demos of model evaluation (classification and object detection) and redshift prediction. See `demo_accuracy_eval.ipynb`, `demo_detection_eval.ipynb`, and `demo_redshift_eval.ipynb` for details.

### 0. Environment

- PyTorch 1.10.0
- torchvision 0.11.0
- timm 0.6.0
- Detectron2 (for detection on neuralens-desi)

### 1. Pre-training

- Dataset

    - The file [./Dataset/DESINet.py](./Dataset/DESINet.py) implements image loading and preprocessing for the pre-training data. The default training dataset is `DESI-1M (DESINet_train_100.txt)`.

- Model

    - The file [models_mae.py](./models_mae.py) defines the models.

- Training

    - The command below pre-trains ViT-Large. To train ViT-Base or ViT-Huge instead, set `--model_name mae_vit_base_patch16` or `--model_name mae_vit_huge_patch14`.
    - For multi-node distributed training, run the following on each of the 8 nodes (8 GPUs per node), setting `$RANK` to that node's index:
    ```bash
    torchrun --nproc_per_node=8 \
         --nnodes=8 \
         --node_rank=$RANK \
         --master_addr=$MASTER_ADDR \
         pretrain.py \
         --model_name mae_vit_large_patch16 \
         --dataset=$dataset \
         --mask_ratio 0.75 \
         --epoch 800 \
         --batch_size 4096 \
         --blr 2.0e-4 \
         --min_lr 0.0 \
         --warmup_epochs 40 \
         --weight_decay 0.05 \
         --save_dir ${SAVE_PATH}
    ```
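
The flags `--blr`, `--warmup_epochs` and `--min_lr` resemble the learning-rate recipe of the original MAE code: a base rate scaled linearly with the total batch size, a linear warmup, then a half-cosine decay. The sketch below assumes that convention, and also assumes that `--batch_size 4096` is the total batch size across all GPUs. `pretrain.py` may differ, and the function names here are illustrative only.

```python
import math


def effective_lr(blr: float, total_batch_size: int) -> float:
    """Linear scaling rule (assumed): lr = blr * total_batch_size / 256."""
    return blr * total_batch_size / 256


def lr_at_epoch(epoch: float, lr: float, min_lr: float,
                warmup_epochs: int, total_epochs: int) -> float:
    """Linear warmup to `lr`, then half-cosine decay down to `min_lr`."""
    if epoch < warmup_epochs:
        return lr * epoch / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return min_lr + (lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))


# Values taken from the pre-training command above.
peak_lr = effective_lr(blr=2.0e-4, total_batch_size=4096)
schedule = [lr_at_epoch(e, peak_lr, 0.0, 40, 800) for e in (0, 20, 40, 420, 800)]
```

Under these assumptions the peak learning rate would be 3.2e-3, reached at the end of the 40 warmup epochs and decayed to `--min_lr` by epoch 800.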
    

### 2. Fine-tuning

#### 2.1 galaxy-desi
  ```bash
  torchrun --nproc_per_node=4 \
         --master_port=23451 \
         ft_galaxy_desi.py \
         --model_name vit_large_patch16 \
         --finetune ${LARGE_CKPT} \
         --drop_path 0.1 \
         --epochs 50 \
         --batch_size 64 \
         --weight_decay 0.5 \
         --lr 2.0e-3 \
         --min_lr 1e-8 \
         --layer_decay 0.75 \
         --warmup_epochs 5 \
         --model_ema \
         --save_dir ${SAVE_PATH}
  ```
- The training commands for `vit_base` and `vit_huge` are similar; see [ft_galaxy_desi.sh](./ft_galaxy_desi.sh).
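
`--layer_decay` most likely refers to layer-wise learning-rate decay as used in BEiT and MAE fine-tuning, where earlier transformer blocks receive a smaller learning rate than later ones. A minimal sketch, assuming that convention and a 24-block ViT-Large; the helper name is hypothetical and the actual parameter grouping in `ft_galaxy_desi.py` may differ:

```python
from typing import List


def layer_lr_scales(num_blocks: int, layer_decay: float) -> List[float]:
    """Per-layer learning-rate multipliers.

    Index 0 is the patch embedding, indices 1..num_blocks are the
    transformer blocks, and the last index is the task head (scale 1.0).
    """
    num_layers = num_blocks + 1
    return [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]


scales = layer_lr_scales(num_blocks=24, layer_decay=0.75)
base_lr = 2.0e-3  # --lr from the command above
layer_lrs = [base_lr * s for s in scales]
```

With these values the head trains at the full rate, the last block at 0.75 times that rate, and each earlier block at a further factor of 0.75.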


#### 2.2 galaxy-sdss

  ```bash
  torchrun --nproc_per_node=4 \
         --master_port=23459 \
         ft_galaxy_sdss.py \
         --model_name vit_base_patch16 \
         --finetune ${BASE_CKPT} \
         --drop_path 0.1 \
         --epochs 50 \
         --batch_size 64 \
         --weight_decay 0.3 \
         --lr 1.5e-3 \
         --min_lr 1e-8 \
         --layer_decay 0.65 \
         --warmup_epochs 5 \
         --model_ema \
         --save_dir ${SAVE_PATH}
  ```

  - For other backbones, please refer to [ft_galaxy_sdss.sh](./ft_galaxy_sdss.sh).



#### 2.3 neuralens-desi

  ```bash
  python3 lazyconfig_train_net.py --config-file ./config/mask_rcnn_vit_b.py --num-gpus 8
  ```

  - Before training, prepare the dataset as described in the Detectron2 tutorial. The file [./config/mask_rcnn_vit_b.py](./config/mask_rcnn_vit_b.py) shows an example configuration.



#### 2.4 redshift-sdss

  ```bash
  torchrun --nproc_per_node=4 \
         --master_port=23459 \
         ft_redshift_sdss.py \
         --model_name vit_base_patch16 \
         --finetune ${BASE_CKPT} \
         --drop_path 0.1 \
         --epochs 50 \
         --batch_size 64 \
         --weight_decay 0.5 \
         --lr 1.5e-3 \
         --min_lr 1e-8 \
         --layer_decay 0.65 \
         --warmup_epochs 5 \
         --model_ema \
         --save_dir ${SAVE_PATH}
  ```
- For other backbones, please refer to [ft_redshift_sdss.sh](./ft_redshift_sdss.sh).
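
The `torchrun` fine-tuning commands above pass `--model_ema`, which presumably keeps an exponential moving average (EMA) of the weights for evaluation. The update rule can be sketched in plain Python as follows. The decay value and the dictionary-of-floats representation are assumptions for illustration, not values taken from the scripts:

```python
from typing import Dict


def ema_update(ema: Dict[str, float], current: Dict[str, float],
               decay: float) -> None:
    """In place, per parameter: ema = decay * ema + (1 - decay) * current."""
    for name, value in current.items():
        ema[name] = decay * ema[name] + (1.0 - decay) * value


# Toy "model" with two scalar parameters; the EMA copy starts from the same values.
weights = {"w": 1.0, "b": 0.0}
ema_weights = dict(weights)

for step in range(3):
    # Stand-in for an optimizer step that changes the live weights.
    weights = {"w": weights["w"] + 1.0, "b": weights["b"] - 0.5}
    ema_update(ema_weights, weights, decay=0.9)
```

The EMA copy lags behind the live weights and smooths out step-to-step noise, which is why it is often the copy used for evaluation.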