# MAL
This repository is the official implementation of our paper **Cluster-Masked Scanning and Pretraining for Enhanced xLSTM Vision Performance**.

![Overview of MAL](overall.jpg)

## Introduction
Modern recurrent architectures such as xLSTM show promise for vision tasks, but their potential has been limited by the difficulty of applying autoregressive pretraining, a cornerstone of NLP's success, to 2D image data. This paper introduces MAL, a framework that enables autoregressive pretraining for vision-oriented xLSTMs. Our core contribution is a cluster-masked pretraining strategy that reorganizes an image into a sequence of semantically meaningful local clusters. The resulting input sequence is more structured and better suited to xLSTM's memory mechanisms. Combined with our cluster scanning strategy, which determines the order in which the clusters are processed, MAL learns strong visual representations by autoregressively predicting entire image regions. Our experiments show that this pretraining scheme lets MAL significantly outperform models trained with conventional supervised learning, better exploit the scaling potential of xLSTM, and establish a new performance benchmark.

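
To make the idea of a cluster sequence concrete, the sketch below regroups a ViT-style patch grid (for example, a 14×14 grid from a 224×224 image with 16×16 patches, which is presumably what `pz16` denotes) into local clusters and builds next-cluster prediction pairs. It is not the MAL implementation: it assumes clusters are fixed, non-overlapping square windows of patches visited in raster order, whereas the paper's clusters are semantically meaningful and their order comes from its cluster scanning strategy. The names `cluster_scan_order` and `next_cluster_pairs` are hypothetical.

```python
from typing import List, Tuple


def cluster_scan_order(grid_h: int, grid_w: int, cluster: int) -> List[List[int]]:
    """Group a grid_h x grid_w patch grid into non-overlapping cluster x cluster
    windows (a hypothetical fixed-window clustering, not MAL's) and return the
    clusters in raster order, each as a list of flat patch indices."""
    if grid_h % cluster or grid_w % cluster:
        raise ValueError("grid size must be divisible by the cluster size")
    clusters = []
    for ci in range(0, grid_h, cluster):          # cluster rows, top to bottom
        for cj in range(0, grid_w, cluster):      # cluster columns, left to right
            members = [
                (ci + di) * grid_w + (cj + dj)    # flat index of patch (row, col)
                for di in range(cluster)
                for dj in range(cluster)
            ]
            clusters.append(members)
    return clusters


def next_cluster_pairs(
    clusters: List[List[int]],
) -> List[Tuple[List[List[int]], List[int]]]:
    """Autoregressive targets: all clusters seen so far -> the next whole cluster."""
    return [(clusters[:t], clusters[t]) for t in range(1, len(clusters))]


if __name__ == "__main__":
    order = cluster_scan_order(4, 4, 2)
    print(order)  # [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]
    for context, target in next_cluster_pairs(order):
        print(f"{len(context)} cluster(s) of context -> predict patches {target}")
```

Predicting a whole cluster at each step, rather than a single patch, is what "predicting entire image regions" refers to; swapping the fixed windows for a content-aware grouping or a different visiting order changes only `cluster_scan_order`.
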
## Installation
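This repository was built using Python 3.10. Install the PyTorch build that matches your CUDA version by following the [official instructions](https://pytorch.org/get-started/locally/); the commands below use PyTorch 2.3.0 with CUDA 12.1.

```bash
conda create -n MAL python=3.10
conda activate MAL
# PyTorch 2.3.0 built for CUDA 12.1; change the index URL to match your CUDA version
pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121
pip install numpy==2.1.3 tensorboard==2.18.0 timm==0.4.12 einops==0.8.0

# make the repository importable
export PYTHONPATH=$PYTHONPATH:/your_program_path/MAL
```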
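
After installation, one way to confirm that the environment matches the pinned versions is to compare the installed package metadata against them. The helper below is hypothetical (it is not part of the repository) and uses only the standard library's `importlib.metadata`. It ignores local version labels such as `+cu121`, which CUDA wheels of PyTorch append to their version string.

```python
from importlib import metadata
from typing import Dict, Optional, Tuple

# Versions pinned in the installation instructions above.
PINNED = {
    "torch": "2.3.0",
    "torchvision": "0.18.0",
    "torchaudio": "2.3.0",
    "numpy": "2.1.3",
    "tensorboard": "2.18.0",
    "timm": "0.4.12",
    "einops": "0.8.0",
}


def version_mismatches(pins: Dict[str, str]) -> Dict[str, Tuple[str, Optional[str]]]:
    """Return {package: (expected, installed)} for every package whose installed
    version differs from its pin; installed is None if the package is missing."""
    mismatches: Dict[str, Tuple[str, Optional[str]]] = {}
    for name, expected in pins.items():
        try:
            # Drop a local label such as "+cu121" before comparing.
            installed: Optional[str] = metadata.version(name).split("+")[0]
        except metadata.PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    problems = version_mismatches(PINNED)
    for name, (expected, installed) in problems.items():
        print(f"{name}: expected {expected}, found {installed}")
    print("environment OK" if not problems else f"{len(problems)} mismatch(es)")
```
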

If you run into any problems during the installation process, please file a GitHub Issue.

## Pretraining
```bash
torchrun --standalone --nproc_per_node=8 --master_port 1221 main_pretrain.py \
    --batch_size 512 \
    --model mal_base_pz16 \
    --norm_pix_loss \
    --epochs 800 \
    --warmup_epochs 40 \
    --blr 5e-4 --weight_decay 0.05 \
    --data_path /path/to/ImageNet/ --output_dir ./out_b/
```
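
If the training scripts follow the MAE codebase convention (an assumption suggested by the shared `--blr`, `--norm_pix_loss`, and `--warmup_epochs` flags, not something this README states), `--batch_size` is per GPU and the actual learning rate is derived from `--blr` by the linear scaling rule, then warmed up linearly and decayed with a half-cycle cosine schedule. The sketch below reproduces that arithmetic for the pretraining command; the function names are hypothetical and `min_lr=0` is assumed.

```python
import math


def effective_lr(blr: float, batch_size: int, world_size: int, accum_iter: int = 1) -> float:
    """Linear scaling rule (MAE-style assumption): lr = blr * effective_batch / 256."""
    eff_batch = batch_size * accum_iter * world_size
    return blr * eff_batch / 256


def lr_at_epoch(epoch: float, peak_lr: float, warmup_epochs: int, epochs: int,
                min_lr: float = 0.0) -> float:
    """Linear warmup to peak_lr, then half-cycle cosine decay down to min_lr."""
    if epoch < warmup_epochs:
        return peak_lr * epoch / warmup_epochs
    progress = (epoch - warmup_epochs) / (epochs - warmup_epochs)
    return min_lr + (peak_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    # Values from the pretraining command: 8 GPUs, --batch_size 512, --blr 5e-4,
    # --warmup_epochs 40, --epochs 800.
    peak = effective_lr(blr=5e-4, batch_size=512, world_size=8)
    print(f"peak lr = {peak:.4g}")  # 0.008 under the assumptions above
    for ep in (0, 20, 40, 420, 800):
        print(ep, f"{lr_at_epoch(ep, peak, warmup_epochs=40, epochs=800):.3e}")
```

Under this convention, changing the number of GPUs or the per-GPU batch size rescales the peak learning rate automatically, so `--blr` can stay fixed.
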

## Finetuning
```bash
cd Finetuning
torchrun --standalone --nproc_per_node 8 main_finetune.py --batch_size 1024 \
    --model mal_base_pz16 --finetune ./checkpoint-base.pth \
    --epochs 200 --global_pool True \
    --blr 5e-4 --layer_decay 0.65 --ema_decay 0.99992 \
    --weight_decay 0.05 --drop_path 0.1 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \
    --dist_eval --data_path /path/to/ImageNet --output_dir ./out_finetune_b/
```
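
The `--layer_decay` and `--ema_decay` flags presumably follow the common BEiT/MAE finetuning recipe; this is an assumption, since the README does not define them. Under that recipe, each layer's learning rate is multiplied by `layer_decay` raised to its distance from the head, and an exponential moving average of the weights is kept alongside the trained model. The helpers below are hypothetical, operate on plain Python floats for clarity, and use a 24-block backbone purely for illustration because the depth of `mal_base_pz16` is not stated here.

```python
from typing import List


def layer_lr_scales(num_blocks: int, layer_decay: float) -> List[float]:
    """Layer-wise lr decay (BEiT/MAE-style assumption).

    Layer id 0 is the patch embedding, ids 1..num_blocks are the backbone
    blocks, and id num_blocks + 1 is the classification head (scale 1.0).
    """
    num_layers = num_blocks + 1
    return [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]


def ema_update(ema: List[float], params: List[float], decay: float) -> List[float]:
    """One EMA step per weight: ema <- decay * ema + (1 - decay) * param."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema, params)]


if __name__ == "__main__":
    # --layer_decay 0.65 with a hypothetical 24-block backbone: the patch
    # embedding trains at 0.65 ** 25 (roughly 2e-5) times the head's rate.
    scales = layer_lr_scales(num_blocks=24, layer_decay=0.65)
    print(f"embed {scales[0]:.2e}, first block {scales[1]:.2e}, head {scales[-1]:.2f}")

    # --ema_decay 0.99992 averages over roughly 1 / (1 - 0.99992) = 12,500 steps.
    print(ema_update([1.0], [0.0], decay=0.99992))
```

A decay this close to 1 makes the EMA weights change slowly, which is why it is typically paired with long schedules such as the 200 epochs used here.
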

## Evaluation
```bash
cd Finetuning
torchrun --standalone --nproc_per_node 8 main_finetune.py --batch_size 1024 \
    --model mal_base_pz16 --finetune ./checkpoint-base.pth \
    --data_path /path/to/ImageNet --eval True
```


## Checkpoint
The pretrained models are available on [Hugging Face 🤗](https://huggingface.co/anonymous405/MAL/tree/main).
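
The checkpoints can also be fetched programmatically with the `huggingface_hub` library. In this sketch, the filename `checkpoint-base.pth` is an assumption taken from the finetuning command, so check the repository's file listing for the actual names. The `{"model": state_dict}` wrapper handled by the hypothetical `extract_state_dict` helper is the layout written by MAE-style training scripts and may not match these files.

```python
from typing import Any, Dict


def download_checkpoint(filename: str = "checkpoint-base.pth",
                        repo_id: str = "anonymous405/MAL",
                        local_dir: str = ".") -> str:
    """Download one file from the Hugging Face repo and return its local path.

    The default filename is an assumption; verify it against the repo listing.
    """
    from huggingface_hub import hf_hub_download  # pip install huggingface_hub
    return hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)


def extract_state_dict(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """Return the model weights, unwrapping a {"model": state_dict} container if present."""
    if isinstance(checkpoint.get("model"), dict):
        return checkpoint["model"]
    return checkpoint


if __name__ == "__main__":
    import torch

    path = download_checkpoint()
    ckpt = torch.load(path, map_location="cpu")
    state_dict = extract_state_dict(ckpt)
    print(f"{len(state_dict)} entries loaded from {path}")
```

The downloaded path can then be passed to `--finetune` in the finetuning and evaluation commands.
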

